Unverified Commit f3d784f3 authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

feat: query instance_id based on routing strategy (#1787)

parent 13560ab2
...@@ -313,19 +313,31 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -313,19 +313,31 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
InstanceSource::Dynamic(_) => { InstanceSource::Dynamic(_) => {
// Extract context ID for request tracking // Extract context ID for request tracking
let context_id = request.context().id().to_string(); let context_id = request.context().id().to_string();
let (instance_id, overlap_amount) = self let (instance_id, overlap_amount) = self
.chooser .chooser
.find_best_match(&context_id, &request.token_ids) .find_best_match(&context_id, &request.token_ids)
.await?; .await?;
let query_instance_id = request.has_annotation("query_instance_id");
// Extract context information before moving the request
let stream_context = request.context().clone();
// Update the request with the estimated prefix hit blocks // Update the request with the estimated prefix hit blocks
let (mut backend_input, context) = request.into_parts(); let (mut backend_input, context) = request.into_parts();
let isl = backend_input.token_ids.len(); let isl = backend_input.token_ids.len();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount); backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
let updated_request = context.map(|_| backend_input); let updated_request = context.map(|_| backend_input);
// if request has the annotation "query_instance_id", for example
// curl -d '{... ,"nvext": { "annotations": ["query_instance_id"]}}'
// request will not be routed to worker immediately
if query_instance_id {
let instance_id_str = instance_id.to_string();
let response =
Annotated::from_annotation("worker_instance_id", &instance_id_str)?;
let stream = stream::iter(vec![response]);
Ok(ResponseStream::new(Box::pin(stream), stream_context))
} else {
// Get the response stream from the worker // Get the response stream from the worker
let mut response_stream = self.inner.direct(updated_request, instance_id).await?; let mut response_stream =
self.inner.direct(updated_request, instance_id).await?;
// Wrap the stream to track tokens // Wrap the stream to track tokens
let stream_context = response_stream.context(); let stream_context = response_stream.context();
...@@ -374,9 +386,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -374,9 +386,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
chooser.free(&request_id).await; chooser.free(&request_id).await;
}); });
Ok(ResponseStream::new(wrapped_stream, stream_context)) Ok(ResponseStream::new(wrapped_stream, stream_context))
} }
} }
} }
}
} }
...@@ -397,9 +397,9 @@ impl OpenAIPreprocessor { ...@@ -397,9 +397,9 @@ impl OpenAIPreprocessor {
// Only set event if not already set to avoid overriding existing events (like errors) // Only set event if not already set to avoid overriding existing events (like errors)
if response.event.is_none() { if response.event.is_none() {
response.event = metrics_annotated.event; response.event = metrics_annotated.event;
}
response.comment = metrics_annotated.comment; response.comment = metrics_annotated.comment;
} }
}
tracing::trace!( tracing::trace!(
request_id = inner.context.id(), request_id = inner.context.id(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment