Unverified Commit 7d13b6e3 authored by atchernych's avatar atchernych Committed by GitHub
Browse files

feat: Prevent double-tokenization when EPP picks worker (#2559)

parent 95ce83d5
...@@ -383,21 +383,29 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -383,21 +383,29 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let query_instance_id = request.has_annotation("query_instance_id"); let query_instance_id = request.has_annotation("query_instance_id");
// Extract context information before moving the request // Extract context information before moving the request
let stream_context = request.context().clone(); let stream_context = request.context().clone();
// Update the request with the estimated prefix hit blocks
let (mut backend_input, context) = request.into_parts();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
let updated_request = context.map(|_| backend_input);
// if request has the annotation "query_instance_id", for example // if request has the annotation "query_instance_id", for example
// curl -d '{... ,"nvext": { "annotations": ["query_instance_id"]}}' // curl -d '{... ,"nvext": { "annotations": ["query_instance_id"]}}'
// request will not be routed to worker immediately // request will not be routed to worker immediately.
// The gateway EPP will receive the worker_instance_id and the tokens.
if query_instance_id { if query_instance_id {
let instance_id_str = instance_id.to_string(); let instance_id_str = instance_id.to_string();
let response = let response =
Annotated::from_annotation("worker_instance_id", &instance_id_str)?; Annotated::from_annotation("worker_instance_id", &instance_id_str)?;
let stream = stream::iter(vec![response]);
// Return the tokens in nvext.token_data format
let response_tokens =
Annotated::from_annotation("token_data", &request.token_ids)?;
tracing::trace!(
"Tokens requested in the response through the query_instance_id annotation: {:?}",
response_tokens
);
let stream = stream::iter(vec![response, response_tokens]);
return Ok(ResponseStream::new(Box::pin(stream), stream_context)); return Ok(ResponseStream::new(Box::pin(stream), stream_context));
} }
// Update the request with the estimated prefix hit blocks
let (mut backend_input, context) = request.into_parts();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
let updated_request = context.map(|_| backend_input);
let mut response_stream = self.inner.direct(updated_request, instance_id).await?; let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
let stream_context = response_stream.context(); let stream_context = response_stream.context();
......
...@@ -194,7 +194,35 @@ impl OpenAIPreprocessor { ...@@ -194,7 +194,35 @@ impl OpenAIPreprocessor {
self.formatter.render(request)? self.formatter.render(request)?
}; };
let encoding = self.tokenizer.encode(&formatted_prompt)?; // Check if backend_instance_id is present and token_data is provided
let has_backend_instance_id = request
.nvext()
.and_then(|ext| ext.backend_instance_id)
.is_some();
let token_data =
request.nvext().and_then(|ext| ext.token_data.as_ref());
let (tokens_vec, skip_token_annotation) = if has_backend_instance_id {
if let Some(tokens) = token_data {
tracing::trace!(
"Using provided tokens from EPP: {} ids",
tokens.len()
);
// need ownership for the builder, so clone.
(tokens.clone(), true)
} else {
tracing::warn!(
"backend_instance_id provided but no token_data; tokenizing prompt"
);
let encoding = self.tokenizer.encode(&formatted_prompt)?;
(encoding.token_ids().to_vec(), false)
}
} else {
// No backend_instance_id provided, continue the normal flow.
let encoding = self.tokenizer.encode(&formatted_prompt)?;
(encoding.token_ids().to_vec(), false)
};
if request.has_annotation(ANNOTATION_FORMATTED_PROMPT) { if request.has_annotation(ANNOTATION_FORMATTED_PROMPT) {
annotations.insert( annotations.insert(
...@@ -203,14 +231,16 @@ impl OpenAIPreprocessor { ...@@ -203,14 +231,16 @@ impl OpenAIPreprocessor {
); );
} }
if request.has_annotation(ANNOTATION_TOKEN_IDS) { if request.has_annotation(ANNOTATION_TOKEN_IDS)
&& !skip_token_annotation
{
annotations.insert( annotations.insert(
ANNOTATION_TOKEN_IDS.to_string(), ANNOTATION_TOKEN_IDS.to_string(),
serde_json::to_string(encoding.token_ids())?, serde_json::to_string(&tokens_vec)?,
); );
} }
builder.token_ids(encoding.token_ids().to_vec()); builder.token_ids(tokens_vec);
} }
TextInput::Batch(texts) => { TextInput::Batch(texts) => {
let token_batches: Vec<Vec<u32>> = texts let token_batches: Vec<Vec<u32>> = texts
......
...@@ -68,6 +68,13 @@ pub struct NvExt { ...@@ -68,6 +68,13 @@ pub struct NvExt {
#[builder(default, setter(strip_option))] #[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub backend_instance_id: Option<i64>, pub backend_instance_id: Option<i64>,
/// Pre-tokenized data to use instead of tokenizing the prompt
/// If provided along with backend_instance_id, these tokens will be used directly
/// and tokenization will be skipped.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub token_data: Option<Vec<u32>>,
/// Guided Decoding Options /// Guided Decoding Options
/// If specified, the output will be a JSON object. Can be a string, an object, or null. /// If specified, the output will be a JSON object. Can be a string, an object, or null.
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Tests requests that carry `token_data` in `nvext`, as used by the EPP-aware Gateway integration. When `backend_instance_id` is present in the request along with `token_data`, tokenization is skipped in preprocessor.rs.
use anyhow::Result;
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
/// Checks that a chat-completion payload carrying `nvext.backend_instance_id`
/// and `nvext.token_data` deserializes into the corresponding `nvext` fields.
#[test]
fn test_request_json_structure() -> Result<()> {
    // Wire format a client is expected to send.
    let payload = r#"{
"model": "qwen",
"messages": [{"role": "user", "content": "Hello"}],
"nvext": {
"backend_instance_id": 12345,
"token_data": [15496, 1917, 264]
}
}"#;

    let parsed: NvCreateChatCompletionRequest = serde_json::from_str(payload)?;

    // The extension block must survive deserialization before its fields are inspected.
    let ext = parsed.nvext.as_ref();
    assert!(ext.is_some());
    let ext = ext.unwrap();

    assert_eq!(ext.backend_instance_id, Some(12345));
    assert_eq!(
        ext.token_data.as_deref(),
        Some([15496, 1917, 264].as_slice())
    );
    Ok(())
}
...@@ -790,3 +790,198 @@ def test_kv_push_router_bindings(request, runtime_services): ...@@ -790,3 +790,198 @@ def test_kv_push_router_bindings(request, runtime_services):
if os.path.exists(mocker_args_file): if os.path.exists(mocker_args_file):
os.unlink(mocker_args_file) os.unlink(mocker_args_file)
@pytest.mark.pre_merge
def test_query_instance_id_returns_worker_and_tokens(request, runtime_services):
    """
    Test that the KV router correctly handles the query_instance_id annotation.

    When a request includes 'nvext.annotations': ['query_instance_id'], the router
    is expected to:
    1. NOT route the request to a worker immediately
    2. Return worker_instance_id as an SSE event
    3. Return token_data as an SSE event containing the request tokens

    The router-side code path under test is the early return:
        if query_instance_id {
            let instance_id_str = instance_id.to_string();
            let response = Annotated::from_annotation("worker_instance_id", &instance_id_str)?;
            let response_tokens = Annotated::from_annotation("token_data", &request.token_ids)?;
            let stream = stream::iter(vec![response, response_tokens]);
            return Ok(ResponseStream::new(Box::pin(stream), stream_context));
        }

    NOTE: termination of the stream with [DONE] is not asserted by this test.
    """

    def parse_sse_events(raw):
        """Return (event_type, payload) pairs for every SSE block that names an event."""
        parsed = []
        for part in raw.split("\n\n"):
            part = part.strip()
            # Blocks without an "event:" header (blank parts, bare "data:" frames)
            # carry no annotation and are ignored.
            if not part.startswith("event:"):
                continue
            lines = part.split("\n")
            event_line = next(
                (line for line in lines if line.startswith("event:")), None
            )
            # The payload is taken from the first "data:" line or ":" comment line.
            data_line = next(
                (line for line in lines if line.startswith(("data:", ":"))), None
            )
            if event_line and data_line:
                event_type = event_line.split(":", 1)[1].strip()
                data_value = data_line.split(":", 1)[1].strip()
                parsed.append((event_type, data_value))
        return parsed

    logger.info("Starting KV router query_instance_id annotation test")

    mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE}
    mocker_args_file = os.path.join(request.node.name, "mocker_args.json")
    os.makedirs(request.node.name, exist_ok=True)
    with open(mocker_args_file, "w") as f:
        json.dump(mocker_args, f)

    # Sentinel instead of inspecting locals() during teardown.
    kv_router = None
    mocker_processes = []

    try:
        # Start KV router (frontend)
        frontend_port = PORT + 30  # Use unique port to avoid conflicts
        logger.info(f"Starting KV router frontend on port {frontend_port}")
        kv_router = KVRouterProcess(request, frontend_port)
        kv_router.__enter__()

        # Start multiple mocker engines to ensure worker selection logic
        endpoint = "dyn://test-namespace.mocker.generate"
        for i in range(NUM_MOCKERS):
            logger.info(f"Starting mocker instance {i} on endpoint {endpoint}")
            mocker = MockerProcess(request, endpoint, mocker_args_file)
            mocker_processes.append(mocker)

        for mocker in mocker_processes:
            mocker.__enter__()

        url = f"http://localhost:{frontend_port}/v1/chat/completions"

        # Send a warming request first to ensure system is ready
        logger.info("Sending warming request without annotations...")
        asyncio.run(send_request_with_retry(url, TEST_PAYLOAD))

        # Test payload with query_instance_id annotation
        annotated_payload = {
            **TEST_PAYLOAD,
            "nvext": {"annotations": ["query_instance_id"]},
        }

        async def query_annotation_response():
            """Send request with query_instance_id and validate response structure"""
            async with aiohttp.ClientSession() as session:
                logger.info("Sending request with query_instance_id annotation...")

                async with session.post(url, json=annotated_payload) as response:
                    assert (
                        response.status == 200
                    ), f"Expected 200 but got {response.status}"

                    # Collect all response chunks
                    response_chunks = []
                    async for chunk in response.content:
                        if chunk:
                            response_chunks.append(
                                chunk.decode("utf-8", errors="replace")
                            )

                    full_response = "".join(response_chunks)
                    logger.info(
                        f"Full SSE response ({len(full_response)} bytes):\n{full_response}"
                    )

                    # Parse and validate the response structure
                    events = parse_sse_events(full_response)
                    logger.info(f"Parsed events: {events}")

                    # Validate worker_instance_id event
                    worker_event = next(
                        (e for e in events if e[0] == "worker_instance_id"), None
                    )
                    assert (
                        worker_event is not None
                    ), f"Missing worker_instance_id event in: {events}"

                    # Validate token_data event
                    token_event = next(
                        (e for e in events if e[0] == "token_data"), None
                    )
                    assert (
                        token_event is not None
                    ), f"Missing token_data event in: {events}"

                    token_data_str = token_event[1].strip('"')
                    try:
                        token_list = json.loads(token_data_str)
                    except json.JSONDecodeError as e:
                        # Chain the decoder error so the failing position stays visible.
                        raise AssertionError(
                            f"token_data is not valid JSON: {token_data_str}, error: {e}"
                        ) from e

                    assert isinstance(
                        token_list, list
                    ), f"token_data should be a list, got: {type(token_list)}"
                    assert (
                        len(token_list) > 0
                    ), f"token_data should not be empty: {token_list}"
                    assert all(
                        isinstance(token, int) for token in token_list
                    ), f"All tokens should be integers: {token_list}"

                    logger.info(
                        f"Valid token_data with {len(token_list)} tokens: {token_list[:10]}{'...' if len(token_list) > 10 else ''}"
                    )

                    # Validate that no actual generation happened (should only be metadata)
                    # This proves the early return worked correctly
                    generation_indicators = [
                        "choices",
                        "content",
                        "delta",
                        "finish_reason",
                    ]
                    lowered_response = full_response.lower()
                    for indicator in generation_indicators:
                        assert (
                            indicator not in lowered_response
                        ), f"Found generation indicator '{indicator}' - request should not have been routed to worker"

                    logger.info(
                        "No generation content found - early return worked correctly"
                    )

                    return {
                        "worker_instance_id": worker_event[1].strip('"'),
                        "token_count": len(token_list),
                        "tokens": token_list,
                    }

        result = asyncio.run(query_annotation_response())

        logger.info("Successfully validated query_instance_id annotation response:")
        logger.info(f"Worker ID: {result['worker_instance_id']}")
        logger.info(f"Token count: {result['token_count']}")

    finally:
        # Tear down in the same order as before: router first, then mockers, then the args file.
        if kv_router is not None:
            kv_router.__exit__(None, None, None)

        for mocker in mocker_processes:
            mocker.__exit__(None, None, None)

        if os.path.exists(mocker_args_file):
            os.unlink(mocker_args_file)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment