Unverified commit b48354c5, authored by Chang Su, committed by GitHub
Browse files

[router][grpc] Fix inconsistent behavior of conversation_id not found (#12299)

parent 29195aaa
...@@ -5,6 +5,7 @@ use std::collections::HashMap; ...@@ -5,6 +5,7 @@ use std::collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use validator::Validate;
// Import shared types from common module // Import shared types from common module
use super::common::{ use super::common::{
...@@ -439,7 +440,7 @@ fn default_top_p() -> Option<f32> { ...@@ -439,7 +440,7 @@ fn default_top_p() -> Option<f32> {
// Request/Response Types // Request/Response Types
// ============================================================================ // ============================================================================
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct ResponsesRequest { pub struct ResponsesRequest {
/// Run the request in the background /// Run the request in the background
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
...@@ -474,6 +475,7 @@ pub struct ResponsesRequest { ...@@ -474,6 +475,7 @@ pub struct ResponsesRequest {
/// Optional conversation id to persist input/output as items /// Optional conversation id to persist input/output as items
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
#[validate(custom(function = "validate_conversation_id"))]
pub conversation: Option<String>, pub conversation: Option<String>,
/// Whether to enable parallel tool calls /// Whether to enable parallel tool calls
...@@ -683,6 +685,33 @@ impl GenerationRequest for ResponsesRequest { ...@@ -683,6 +685,33 @@ impl GenerationRequest for ResponsesRequest {
} }
} }
/// Validate the format of a user-supplied conversation ID.
///
/// A valid ID begins with the literal prefix `conv_` and consists solely of
/// ASCII letters, ASCII digits, underscores, and dashes.
///
/// # Errors
///
/// Returns a [`validator::ValidationError`] with code `invalid_conversation_id`
/// and a message quoting the offending value when either
/// - the ID does not start with `conv_`, or
/// - the ID contains a character outside the allowed set.
pub fn validate_conversation_id(conv_id: &str) -> Result<(), validator::ValidationError> {
    // Both failure modes share one error code and differ only in their message.
    let invalid = |message: String| {
        let mut error = validator::ValidationError::new("invalid_conversation_id");
        error.message = Some(std::borrow::Cow::Owned(message));
        error
    };

    if !conv_id.starts_with("conv_") {
        return Err(invalid(format!(
            "Invalid 'conversation': '{}'. Expected an ID that begins with 'conv_'.",
            conv_id
        )));
    }

    // ASCII-only on purpose: `char::is_alphanumeric` also accepts non-ASCII
    // letters and digits (e.g. 'é', '٣'), which would let IDs through that the
    // error message below describes as invalid.
    let has_only_allowed_chars = conv_id
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-');
    if !has_only_allowed_chars {
        return Err(invalid(format!(
            "Invalid 'conversation': '{}'. Expected an ID that contains letters, numbers, underscores, or dashes, but this value contained additional characters.",
            conv_id
        )));
    }

    Ok(())
}
/// Normalize a SimpleInputMessage to a proper Message item /// Normalize a SimpleInputMessage to a proper Message item
/// ///
/// This helper converts SimpleInputMessage (which can have flexible content) /// This helper converts SimpleInputMessage (which can have flexible content)
......
...@@ -48,6 +48,7 @@ use tokio::sync::{mpsc, RwLock}; ...@@ -48,6 +48,7 @@ use tokio::sync::{mpsc, RwLock};
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, warn}; use tracing::{debug, error, warn};
use uuid::Uuid; use uuid::Uuid;
use validator::Validate;
use super::{ use super::{
conversions, conversions,
...@@ -109,7 +110,32 @@ pub async fn route_responses( ...@@ -109,7 +110,32 @@ pub async fn route_responses(
} }
} }
// 1. Validate mutually exclusive parameters // 1. Validate request (includes conversation ID format)
if let Err(validation_errors) = request.validate() {
// Extract the first error message for conversation field
let error_message = validation_errors
.field_errors()
.get("conversation")
.and_then(|errors| errors.first())
.and_then(|error| error.message.as_ref())
.map(|msg| msg.to_string())
.unwrap_or_else(|| "Invalid request parameters".to_string());
return (
StatusCode::BAD_REQUEST,
axum::Json(json!({
"error": {
"message": error_message,
"type": "invalid_request_error",
"param": "conversation",
"code": "invalid_value"
}
})),
)
.into_response();
}
// 2. Validate mutually exclusive parameters
if request.previous_response_id.is_some() && request.conversation.is_some() { if request.previous_response_id.is_some() && request.conversation.is_some() {
return ( return (
StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
...@@ -125,7 +151,7 @@ pub async fn route_responses( ...@@ -125,7 +151,7 @@ pub async fn route_responses(
.into_response(); .into_response();
} }
// 2. Check for incompatible parameter combinations // 3. Check for incompatible parameter combinations
let is_streaming = request.stream.unwrap_or(false); let is_streaming = request.stream.unwrap_or(false);
let is_background = request.background.unwrap_or(false); let is_background = request.background.unwrap_or(false);
...@@ -144,7 +170,7 @@ pub async fn route_responses( ...@@ -144,7 +170,7 @@ pub async fn route_responses(
.into_response(); .into_response();
} }
// 3. Route based on execution mode // 4. Route based on execution mode
if is_streaming { if is_streaming {
route_responses_streaming(ctx, request, headers, model_id).await route_responses_streaming(ctx, request, headers, model_id).await
} else if is_background { } else if is_background {
...@@ -928,33 +954,23 @@ async fn load_conversation_history( ...@@ -928,33 +954,23 @@ async fn load_conversation_history(
if let Some(ref conv_id_str) = request.conversation { if let Some(ref conv_id_str) = request.conversation {
let conv_id = ConversationId::from(conv_id_str.as_str()); let conv_id = ConversationId::from(conv_id_str.as_str());
// Auto-create conversation if it doesn't exist (OpenAI behavior) // Check if conversation exists - return error if not found
if let Ok(None) = ctx.conversation_storage.get_conversation(&conv_id).await { let conversation = ctx
debug!( .conversation_storage
"Creating new conversation with user-provided ID: {}", .get_conversation(&conv_id)
.await
.map_err(|e| {
crate::routers::grpc::utils::internal_error_message(format!(
"Failed to check conversation: {}",
e
))
})?;
if conversation.is_none() {
return Err(crate::routers::grpc::utils::bad_request_error(format!(
"Conversation '{}' not found. Please create the conversation first using the conversations API.",
conv_id_str conv_id_str
); )));
// Convert HashMap to JsonMap for metadata
let metadata = request.metadata.as_ref().map(|m| {
m.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect::<serde_json::Map<String, serde_json::Value>>()
});
let new_conv = crate::data_connector::NewConversation {
id: Some(conv_id.clone()), // Use user-provided conversation ID
metadata,
};
ctx.conversation_storage
.create_conversation(new_conv)
.await
.map_err(|e| {
crate::routers::grpc::utils::internal_error_message(format!(
"Failed to create conversation: {}",
e
))
})?;
} }
// Load conversation history // Load conversation history
......
...@@ -6,3 +6,4 @@ mod chat_completion; ...@@ -6,3 +6,4 @@ mod chat_completion;
mod chat_message; mod chat_message;
mod embedding; mod embedding;
mod rerank; mod rerank;
mod responses;
use sglang_router_rs::protocols::responses::{ResponseInput, ResponsesRequest};
use validator::Validate;
/// Well-formed conversation IDs must be accepted by request validation
#[test]
fn test_validate_conversation_id_valid() {
    let accepted = [
        "conv_123",
        "conv_test-123_abc",
        "conv_ABC_123",
        "conv_my_conversation_123",
        "conv_456",
        "conv_test123",
    ];

    for candidate in accepted {
        let request = ResponsesRequest {
            conversation: Some(candidate.to_string()),
            input: ResponseInput::Text("test".to_string()),
            ..Default::default()
        };

        // Validate once and reuse the outcome for both the check and the report.
        let outcome = request.validate();
        assert!(
            outcome.is_ok(),
            "Expected '{}' to be valid, but got error: {:?}",
            candidate,
            outcome.as_ref().err()
        );
    }
}
/// Malformed conversation IDs must be rejected with an error on the
/// `conversation` field whose message quotes the offending value
#[test]
fn test_validate_conversation_id_invalid() {
    let rejected = [
        // No 'conv_' prefix
        "test-conv-streaming",
        "conversation-456",
        "my_conversation_123",
        "ABC123",
        "test_123_conv",
        "conv123", // underscore after 'conv' is missing
        // Characters outside the allowed set
        "conv_.test",     // dot
        "conv_ test",     // space
        "conv_@test",     // at sign
        "conv_/test",     // slash
        "conv_\\test",    // backslash
        "conv_:test",     // colon
        "conv_;test",     // semicolon
        "conv_,test",     // comma
        "conv_+test",     // plus
        "conv_=test",     // equals
        "conv_[test]",    // brackets
        "conv_{test}",    // braces
        "conv_(test)",    // parentheses
        "conv_!test",     // exclamation mark
        "conv_?test",     // question mark
        "conv_#test",     // hash
        "conv_$test",     // dollar sign
        "conv_%test",     // percent
        "conv_&test",     // ampersand
        "conv_*test",     // asterisk
        "conv_ test-123", // space
    ];

    for candidate in rejected {
        let request = ResponsesRequest {
            conversation: Some(candidate.to_string()),
            input: ResponseInput::Text("test".to_string()),
            ..Default::default()
        };

        let outcome = request.validate();
        assert!(
            outcome.is_err(),
            "Expected '{}' to be invalid, but validation passed",
            candidate
        );
        // The assertion above guarantees the `Err` variant.
        let errors = outcome.unwrap_err();

        // The failure must be attributed to the conversation field
        let by_field = errors.field_errors();
        let for_conversation = by_field.get("conversation");
        assert!(
            for_conversation.is_some(),
            "Expected error for 'conversation' field, but got errors for: {:?}",
            by_field.keys()
        );

        let first_message = for_conversation
            .and_then(|errs| errs.first())
            .and_then(|err| err.message.as_ref())
            .map(|text| text.to_string());
        let Some(msg) = first_message else {
            panic!("Expected error message for conversation field");
        };

        assert!(
            msg.contains("Invalid 'conversation'"),
            "Error message should mention 'conversation', got: {}",
            msg
        );
        assert!(
            msg.contains(candidate),
            "Error message should include the invalid ID '{}', got: {}",
            candidate,
            msg
        );
    }
}
/// A request that carries no conversation ID at all must pass validation
#[test]
fn test_validate_conversation_id_none() {
    let without_conversation = ResponsesRequest {
        conversation: None,
        input: ResponseInput::Text("test".to_string()),
        ..Default::default()
    };

    let outcome = without_conversation.validate();
    assert!(
        outcome.is_ok(),
        "Request with no conversation ID should be valid"
    );
}
/// The invalid-character error must follow OpenAI's wording: a fixed prefix,
/// the list of allowed characters, and the rejected ID itself
#[test]
fn test_validate_conversation_id_error_message_format() {
    let invalid_id = "conv_.test-conv-streaming";
    let request = ResponsesRequest {
        conversation: Some(invalid_id.to_string()),
        input: ResponseInput::Text("test".to_string()),
        ..Default::default()
    };

    let result = request.validate();
    assert!(result.is_err());
    // The assertion above guarantees the `Err` variant.
    let errors = result.unwrap_err();

    let by_field = errors.field_errors();
    let first_error = by_field
        .get("conversation")
        .and_then(|errs| errs.first());
    let error_msg = first_error
        .and_then(|err| err.message.as_ref())
        .map(|text| text.to_string())
        .unwrap();

    // Compare against the OpenAI-style message layout
    assert!(
        error_msg.starts_with("Invalid 'conversation':"),
        "Error should start with \"Invalid 'conversation':\""
    );
    assert!(
        error_msg.contains("letters, numbers, underscores, or dashes"),
        "Error should mention valid characters"
    );
    assert!(
        error_msg.contains(invalid_id),
        "Error should include the invalid conversation ID"
    );
}
/// The missing-prefix error must follow OpenAI's wording: a fixed prefix,
/// a mention of the required 'conv_' start, and the rejected ID itself
#[test]
fn test_validate_conversation_id_missing_prefix() {
    let invalid_id = "test-conv-streaming";
    let request = ResponsesRequest {
        conversation: Some(invalid_id.to_string()),
        input: ResponseInput::Text("test".to_string()),
        ..Default::default()
    };

    let result = request.validate();
    assert!(result.is_err());
    // The assertion above guarantees the `Err` variant.
    let errors = result.unwrap_err();

    let by_field = errors.field_errors();
    let first_error = by_field
        .get("conversation")
        .and_then(|errs| errs.first());
    let error_msg = first_error
        .and_then(|err| err.message.as_ref())
        .map(|text| text.to_string())
        .unwrap();

    // Compare against the OpenAI-style message layout
    assert!(
        error_msg.starts_with("Invalid 'conversation':"),
        "Error should start with \"Invalid 'conversation':\""
    );
    assert!(
        error_msg.contains("begins with 'conv_'"),
        "Error should mention the required prefix, got: {}",
        error_msg
    );
    assert!(
        error_msg.contains(invalid_id),
        "Error should include the invalid conversation ID"
    );
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment