Unverified Commit 93ada899 authored by zhongdaor-nv's avatar zhongdaor-nv Committed by GitHub
Browse files

feat: enable HTTP completion endpoint to accept arrays of prompts and generate...


feat: enable HTTP completion endpoint to accept arrays of prompts and generate multiple completions per prompt (#3953)
Signed-off-by: default avatarzhongdaor <zhongdaor@nvidia.com>
parent 6bccf099
...@@ -318,11 +318,33 @@ async fn completions( ...@@ -318,11 +318,33 @@ async fn completions(
request: Context<NvCreateCompletionRequest>, request: Context<NvCreateCompletionRequest>,
stream_handle: ConnectionHandle, stream_handle: ConnectionHandle,
) -> Result<Response, ErrorResponse> { ) -> Result<Response, ErrorResponse> {
use crate::protocols::openai::completions::get_prompt_batch_size;
// return a 503 if the service is not ready // return a 503 if the service is not ready
check_ready(&state)?; check_ready(&state)?;
validate_completion_fields_generic(&request)?; validate_completion_fields_generic(&request)?;
// Detect batch prompts
let batch_size = get_prompt_batch_size(&request.inner.prompt);
let n = request.inner.n.unwrap_or(1);
// If single prompt or single-element batch, use original flow
if batch_size == 1 {
return completions_single(state, request, stream_handle).await;
}
// Batch processing: handle multiple prompts
completions_batch(state, request, stream_handle, batch_size, n).await
}
/// Handle single prompt completions (original logic)
#[tracing::instrument(skip_all)]
async fn completions_single(
state: Arc<service_v2::State>,
request: Context<NvCreateCompletionRequest>,
stream_handle: ConnectionHandle,
) -> Result<Response, ErrorResponse> {
let request_id = request.id().to_string(); let request_id = request.id().to_string();
// todo - decide on default // todo - decide on default
...@@ -433,6 +455,162 @@ async fn completions( ...@@ -433,6 +455,162 @@ async fn completions(
} }
} }
/// Handle batch prompt completions (multiple prompts with n choices each)
#[tracing::instrument(skip_all)]
async fn completions_batch(
state: Arc<service_v2::State>,
request: Context<NvCreateCompletionRequest>,
stream_handle: ConnectionHandle,
batch_size: usize,
n: u8,
) -> Result<Response, ErrorResponse> {
use crate::protocols::openai::completions::extract_single_prompt;
use futures::stream::{self, StreamExt};
let request_id = request.id().to_string();
let streaming = request.inner.stream.unwrap_or(false);
let model = request.inner.model.clone();
// Create http_queue_guard early - tracks time waiting to be processed
let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);
let engine = state
.manager()
.get_completions_engine(&model)
.map_err(|_| ErrorMessage::model_not_found())?;
let parsing_options = state.manager().get_parsing_options(&model);
let mut response_collector = state.metrics_clone().create_response_collector(&model);
// prepare to process any annotations
let annotations = request.annotations();
// Create inflight_guard before calling engine to ensure errors are counted
let mut inflight_guard =
state
.metrics_clone()
.create_inflight_guard(&model, Endpoint::Completions, streaming);
// Generate streams for each prompt in the batch
let mut all_streams = Vec::new();
let mut first_ctx = None;
for prompt_idx in 0..batch_size {
// Extract single prompt at this index
let single_prompt = extract_single_prompt(&request.inner.prompt, prompt_idx);
// Create a new request with this single prompt
let mut single_request = request.content().clone();
single_request.inner.prompt = single_prompt;
// Generate unique request_id for each prompt: original_id-{prompt_idx}
let unique_request_id = format!("{}-{}", request.id(), prompt_idx);
let single_request_context = Context::with_id(single_request, unique_request_id);
// Generate stream for this prompt
let stream = engine
.generate(single_request_context)
.await
.map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate completions"))?;
// Capture context from first stream
if first_ctx.is_none() {
first_ctx = Some(stream.context());
}
// Remap choice indices: choice.index += prompt_idx * n
let prompt_idx_u32 = prompt_idx as u32;
let n_u32 = n as u32;
let remapped_stream = stream.map(move |mut response| {
if let Some(ref mut data) = response.data {
for choice in &mut data.inner.choices {
choice.index += prompt_idx_u32 * n_u32;
}
}
response
});
all_streams.push(remapped_stream);
}
// Merge all streams
let merged_stream = stream::select_all(all_streams);
// capture the context to cancel the stream if the client disconnects
let ctx = first_ctx.expect("At least one stream should be generated");
let annotations_vec = annotations.map_or(Vec::new(), |annotations| {
annotations
.iter()
.filter_map(|annotation| {
if annotation == ANNOTATION_REQUEST_ID {
Annotated::<NvCreateCompletionResponse>::from_annotation(
ANNOTATION_REQUEST_ID,
&request_id,
)
.ok()
} else {
None
}
})
.collect::<Vec<_>>()
});
// apply any annotations to the front of the stream
let merged_stream = stream::iter(annotations_vec).chain(merged_stream);
if streaming {
// For streaming, we'll drop the http_queue_guard on the first token
let mut http_queue_guard = Some(http_queue_guard);
let stream = merged_stream.map(move |response| {
// Calls observe_response() on each token
process_response_using_event_converter_and_observe_metrics(
EventConverter::from(response),
&mut response_collector,
&mut http_queue_guard,
)
});
let stream = monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);
let mut sse_stream = Sse::new(stream);
if let Some(keep_alive) = state.sse_keep_alive() {
sse_stream = sse_stream.keep_alive(KeepAlive::default().interval(keep_alive));
}
Ok(sse_stream.into_response())
} else {
// Tap the stream to collect metrics for non-streaming requests without altering items
let mut http_queue_guard = Some(http_queue_guard);
let stream = merged_stream.inspect(move |response| {
// Calls observe_response() on each token - drops http_queue_guard on first token
process_response_and_observe_metrics(
response,
&mut response_collector,
&mut http_queue_guard,
);
});
let response = NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
.await
.map_err(|e| {
tracing::error!(
"Failed to fold completions stream for {}: {:?}",
request_id,
e
);
ErrorMessage::internal_server_error(&format!(
"Failed to fold completions stream for {}: {:?}",
request_id, e
))
})?;
inflight_guard.mark_ok();
Ok(Json(response).into_response())
}
}
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
async fn embeddings( async fn embeddings(
State(state): State<Arc<service_v2::State>>, State(state): State<Arc<service_v2::State>>,
......
...@@ -78,6 +78,39 @@ pub fn prompt_to_string(prompt: &dynamo_async_openai::types::Prompt) -> String { ...@@ -78,6 +78,39 @@ pub fn prompt_to_string(prompt: &dynamo_async_openai::types::Prompt) -> String {
} }
} }
/// Get the batch size from a prompt (1 for single prompts, array length for batch prompts)
pub fn get_prompt_batch_size(prompt: &dynamo_async_openai::types::Prompt) -> usize {
match prompt {
dynamo_async_openai::types::Prompt::String(_) => 1,
dynamo_async_openai::types::Prompt::IntegerArray(_) => 1,
dynamo_async_openai::types::Prompt::StringArray(arr) => arr.len(),
dynamo_async_openai::types::Prompt::ArrayOfIntegerArray(arr) => arr.len(),
}
}
/// Extract a single prompt from a batch at the given index.
/// For single prompts, returns a clone regardless of index.
/// For batch prompts, returns the prompt at the specified index.
pub fn extract_single_prompt(
prompt: &dynamo_async_openai::types::Prompt,
index: usize,
) -> dynamo_async_openai::types::Prompt {
match prompt {
dynamo_async_openai::types::Prompt::String(s) => {
dynamo_async_openai::types::Prompt::String(s.clone())
}
dynamo_async_openai::types::Prompt::IntegerArray(arr) => {
dynamo_async_openai::types::Prompt::IntegerArray(arr.clone())
}
dynamo_async_openai::types::Prompt::StringArray(arr) => {
dynamo_async_openai::types::Prompt::String(arr[index].clone())
}
dynamo_async_openai::types::Prompt::ArrayOfIntegerArray(arr) => {
dynamo_async_openai::types::Prompt::IntegerArray(arr[index].clone())
}
}
}
impl NvExtProvider for NvCreateCompletionRequest { impl NvExtProvider for NvCreateCompletionRequest {
fn nvext(&self) -> Option<&NvExt> { fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref() self.nvext.as_ref()
...@@ -403,7 +436,11 @@ impl ValidateRequest for NvCreateCompletionRequest { ...@@ -403,7 +436,11 @@ impl ValidateRequest for NvCreateCompletionRequest {
validate::validate_top_k(self.get_top_k())?; validate::validate_top_k(self.get_top_k())?;
// Cross-field validation // Cross-field validation
validate::validate_n_with_temperature(self.inner.n, self.inner.temperature)?; validate::validate_n_with_temperature(self.inner.n, self.inner.temperature)?;
// total choices validation for completions batch requests
validate::validate_total_choices(
get_prompt_batch_size(&self.inner.prompt),
self.inner.n.unwrap_or(1),
)?;
Ok(()) Ok(())
} }
} }
...@@ -66,6 +66,9 @@ pub const MAX_N: u8 = 128; ...@@ -66,6 +66,9 @@ pub const MAX_N: u8 = 128;
/// Allowed range of values for `n` (number of choices) /// Allowed range of values for `n` (number of choices)
pub const N_RANGE: (u8, u8) = (MIN_N, MAX_N); pub const N_RANGE: (u8, u8) = (MIN_N, MAX_N);
/// Maximum allowed total number of choices (batch_size × n)
pub const MAX_TOTAL_CHOICES: usize = 128;
/// Minimum allowed value for OpenAI's `logit_bias` values /// Minimum allowed value for OpenAI's `logit_bias` values
pub const MIN_LOGIT_BIAS: f32 = -100.0; pub const MIN_LOGIT_BIAS: f32 = -100.0;
/// Maximum allowed value for OpenAI's `logit_bias` values /// Maximum allowed value for OpenAI's `logit_bias` values
...@@ -261,6 +264,21 @@ pub fn validate_n(n: Option<u8>) -> Result<(), anyhow::Error> { ...@@ -261,6 +264,21 @@ pub fn validate_n(n: Option<u8>) -> Result<(), anyhow::Error> {
Ok(()) Ok(())
} }
/// Validates total choices (batch_size × n) doesn't exceed maximum
pub fn validate_total_choices(batch_size: usize, n: u8) -> Result<(), anyhow::Error> {
let total_choices = batch_size * (n as usize);
if total_choices > MAX_TOTAL_CHOICES {
anyhow::bail!(
"Total choices (batch_size × n = {} × {} = {}) exceeds maximum of {}",
batch_size,
n,
total_choices,
MAX_TOTAL_CHOICES
);
}
Ok(())
}
/// Validates n and temperature interaction /// Validates n and temperature interaction
/// When n > 1, temperature must be > 0 to ensure diverse outputs /// When n > 1, temperature must be > 0 to ensure diverse outputs
pub fn validate_n_with_temperature( pub fn validate_n_with_temperature(
......
...@@ -118,3 +118,144 @@ fn build_samples() -> Result<Vec<CompletionSample>, String> { ...@@ -118,3 +118,144 @@ fn build_samples() -> Result<Vec<CompletionSample>, String> {
Ok(samples) Ok(samples)
} }
// ============================================================================
// Batch Prompt Tests
// ============================================================================
#[test]
fn test_batch_prompt_utilities() {
use dynamo_async_openai::types::Prompt;
use dynamo_llm::protocols::openai::completions::{
extract_single_prompt, get_prompt_batch_size,
};
// Test single string prompt
let single_string = Prompt::String("Hello, world!".to_string());
assert_eq!(get_prompt_batch_size(&single_string), 1);
assert_eq!(
extract_single_prompt(&single_string, 0),
Prompt::String("Hello, world!".to_string())
);
// Test single integer array prompt
let single_int = Prompt::IntegerArray(vec![1, 2, 3]);
assert_eq!(get_prompt_batch_size(&single_int), 1);
assert_eq!(
extract_single_prompt(&single_int, 0),
Prompt::IntegerArray(vec![1, 2, 3])
);
// Test string array prompt
let string_array = Prompt::StringArray(vec![
"First prompt".to_string(),
"Second prompt".to_string(),
"Third prompt".to_string(),
]);
assert_eq!(get_prompt_batch_size(&string_array), 3);
assert_eq!(
extract_single_prompt(&string_array, 0),
Prompt::String("First prompt".to_string())
);
assert_eq!(
extract_single_prompt(&string_array, 1),
Prompt::String("Second prompt".to_string())
);
assert_eq!(
extract_single_prompt(&string_array, 2),
Prompt::String("Third prompt".to_string())
);
// Test array of integer arrays
let int_array = Prompt::ArrayOfIntegerArray(vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8, 9]]);
assert_eq!(get_prompt_batch_size(&int_array), 3);
assert_eq!(
extract_single_prompt(&int_array, 0),
Prompt::IntegerArray(vec![1, 2, 3])
);
assert_eq!(
extract_single_prompt(&int_array, 1),
Prompt::IntegerArray(vec![4, 5])
);
assert_eq!(
extract_single_prompt(&int_array, 2),
Prompt::IntegerArray(vec![6, 7, 8, 9])
);
}
#[test]
fn test_total_choices_validation() {
use dynamo_llm::protocols::openai::validate::validate_total_choices;
// Valid cases
assert!(validate_total_choices(1, 1).is_ok());
assert!(validate_total_choices(10, 10).is_ok());
assert!(validate_total_choices(64, 2).is_ok());
assert!(validate_total_choices(128, 1).is_ok());
assert!(validate_total_choices(1, 128).is_ok());
// Edge case: exactly at the limit
assert!(validate_total_choices(128, 1).is_ok());
assert!(validate_total_choices(64, 2).is_ok());
// Invalid cases: exceeds limit
assert!(validate_total_choices(129, 1).is_err());
assert!(validate_total_choices(65, 2).is_err());
assert!(validate_total_choices(100, 2).is_err());
assert!(validate_total_choices(2, 100).is_err());
// Test error message
let result = validate_total_choices(100, 2);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.to_string()
.contains("Total choices (batch_size × n = 100 × 2 = 200) exceeds maximum of 128")
);
}
#[test]
fn test_batch_prompt_with_n_parameter() {
use dynamo_async_openai::types::Prompt;
use dynamo_llm::protocols::openai::completions::get_prompt_batch_size;
// Test batch size calculation
let prompt = Prompt::StringArray(vec!["p1".to_string(), "p2".to_string(), "p3".to_string()]);
let batch_size = get_prompt_batch_size(&prompt);
let n = 2_u8;
// Total choices = batch_size × n = 3 × 2 = 6
let total_choices = batch_size * (n as usize);
assert_eq!(total_choices, 6);
// Choice indices should be:
// prompt 0: indices 0, 1
// prompt 1: indices 2, 3
// prompt 2: indices 4, 5
for prompt_idx in 0..batch_size {
for choice_idx in 0..n {
let expected_index = (prompt_idx as u32) * (n as u32) + (choice_idx as u32);
// Verify index calculation matches vLLM logic
assert_eq!(
expected_index,
prompt_idx as u32 * n as u32 + choice_idx as u32
);
}
}
}
#[test]
fn test_single_prompt_in_array() {
use dynamo_async_openai::types::Prompt;
use dynamo_llm::protocols::openai::completions::{
extract_single_prompt, get_prompt_batch_size,
};
// Single element array should work like regular prompt
let single_in_array = Prompt::StringArray(vec!["Single prompt".to_string()]);
assert_eq!(get_prompt_batch_size(&single_in_array), 1);
assert_eq!(
extract_single_prompt(&single_in_array, 0),
Prompt::String("Single prompt".to_string())
);
}
...@@ -162,6 +162,24 @@ def test_completion_string_prompt() -> None: ...@@ -162,6 +162,24 @@ def test_completion_string_prompt() -> None:
) )
@pytest.mark.usefixtures("start_services")
@pytest.mark.e2e
@pytest.mark.model(TEST_MODEL)
def test_completion_empty_array_prompt() -> None:
payload: Dict[str, Any] = {
"model": TEST_MODEL,
"prompt": [],
"max_tokens": 2000,
}
response = _send_completion_request(payload)
assert response.status_code == 400, (
f"Completion request should failed with status 400 but got"
f"{response.status_code}: {response.text}"
)
@pytest.mark.usefixtures("start_services") @pytest.mark.usefixtures("start_services")
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(TEST_MODEL) @pytest.mark.model(TEST_MODEL)
...@@ -186,13 +204,25 @@ def test_completion_single_element_array_prompt() -> None: ...@@ -186,13 +204,25 @@ def test_completion_single_element_array_prompt() -> None:
def test_completion_multi_element_array_prompt() -> None: def test_completion_multi_element_array_prompt() -> None:
payload: Dict[str, Any] = { payload: Dict[str, Any] = {
"model": TEST_MODEL, "model": TEST_MODEL,
"prompt": ["Tell me about Mars", "Tell me about Ceres"], "prompt": [
"max_tokens": 2000, "Tell me about Mars",
"Tell me about Ceres",
"Tell me about Jupiter",
],
"max_tokens": 300,
} }
response = _send_completion_request(payload) response = _send_completion_request(payload)
response_data = response.json()
assert response.status_code == 200, (
f"Completion request failed with status "
f"{response.status_code}: {response.text}"
)
expected_choices = len(payload.get("prompt")) # type: ignore
choices = len(response_data.get("choices", []))
# request should fail because we are sending multiple prompts
assert ( assert (
response.status_code == 500 expected_choices == choices
), f"Request should fail with code 500; response:{response.text}" ), f"Expected {expected_choices} choices, got {choices}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment