Unverified Commit 4d3e1ae3 authored by MatejKosec's avatar MatejKosec Committed by GitHub
Browse files

feat: Full Anthropic Messages API cache_control support (top-level, per-block,...


feat: Full Anthropic Messages API cache_control support (top-level, per-block, system block arrays) (#6629)
Signed-off-by: Matej Kosec <mkosec@nvidia.com>
parent a3cf35c3
...@@ -272,6 +272,9 @@ impl OpenAIPreprocessor { ...@@ -272,6 +272,9 @@ impl OpenAIPreprocessor {
builder.mdc_sum(Some(self.mdcsum.clone())); builder.mdc_sum(Some(self.mdcsum.clone()));
let lora_name = self.lora_name.clone(); let lora_name = self.lora_name.clone();
// Extract cache_control TTL from either nvext or top-level field
let cache_control_ttl = request.effective_cache_control().map(|cc| cc.ttl_seconds());
// Extract routing hints from nvext if present // Extract routing hints from nvext if present
if let Some(nvext) = request.nvext() { if let Some(nvext) = request.nvext() {
// Build routing hints from nvext fields // Build routing hints from nvext fields
...@@ -289,10 +292,12 @@ impl OpenAIPreprocessor { ...@@ -289,10 +292,12 @@ impl OpenAIPreprocessor {
allowed_worker_ids: None, allowed_worker_ids: None,
}; };
builder.routing(Some(routing)); builder.routing(Some(routing));
} else if lora_name.is_some() { } else if lora_name.is_some() || cache_control_ttl.is_some() {
// Ensure LoRA-aware routing still gets hints even when nvext is absent. // Ensure routing hints exist when we have LoRA or cache_control,
// even when nvext is absent (e.g. Anthropic endpoint requests).
builder.routing(Some(RoutingHints { builder.routing(Some(RoutingHints {
lora_name, lora_name,
cache_control_ttl,
..Default::default() ..Default::default()
})); }));
} }
......
...@@ -30,6 +30,7 @@ pub struct AnthropicStreamConverter { ...@@ -30,6 +30,7 @@ pub struct AnthropicStreamConverter {
// Token usage (from engine) // Token usage (from engine)
input_token_count: u32, input_token_count: u32,
output_token_count: u32, output_token_count: u32,
cached_token_count: Option<u32>,
// Tool call tracking // Tool call tracking
tool_call_states: Vec<ToolCallState>, tool_call_states: Vec<ToolCallState>,
tool_calls_sent: HashSet<String>, tool_calls_sent: HashSet<String>,
...@@ -57,6 +58,7 @@ impl AnthropicStreamConverter { ...@@ -57,6 +58,7 @@ impl AnthropicStreamConverter {
text_block_index: 0, text_block_index: 0,
input_token_count: 0, input_token_count: 0,
output_token_count: 0, output_token_count: 0,
cached_token_count: None,
tool_call_states: Vec::new(), tool_call_states: Vec::new(),
tool_calls_sent: HashSet::new(), tool_calls_sent: HashSet::new(),
next_block_index: 0, next_block_index: 0,
...@@ -77,6 +79,8 @@ impl AnthropicStreamConverter { ...@@ -77,6 +79,8 @@ impl AnthropicStreamConverter {
usage: AnthropicUsage { usage: AnthropicUsage {
input_tokens: 0, input_tokens: 0,
output_tokens: 0, output_tokens: 0,
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
}, },
}; };
...@@ -95,6 +99,10 @@ impl AnthropicStreamConverter { ...@@ -95,6 +99,10 @@ impl AnthropicStreamConverter {
if let Some(usage) = &chunk.usage { if let Some(usage) = &chunk.usage {
self.input_token_count = usage.prompt_tokens; self.input_token_count = usage.prompt_tokens;
self.output_token_count = usage.completion_tokens; self.output_token_count = usage.completion_tokens;
self.cached_token_count = usage
.prompt_tokens_details
.as_ref()
.and_then(|d| d.cached_tokens);
} }
for choice in &chunk.choices { for choice in &chunk.choices {
...@@ -138,6 +146,7 @@ impl AnthropicStreamConverter { ...@@ -138,6 +146,7 @@ impl AnthropicStreamConverter {
index: self.text_block_index, index: self.text_block_index,
content_block: AnthropicResponseContentBlock::Text { content_block: AnthropicResponseContentBlock::Text {
text: String::new(), text: String::new(),
citations: None,
}, },
}; };
events.push(make_sse_event("content_block_start", &block_start)); events.push(make_sse_event("content_block_start", &block_start));
...@@ -271,6 +280,8 @@ impl AnthropicStreamConverter { ...@@ -271,6 +280,8 @@ impl AnthropicStreamConverter {
usage: AnthropicUsage { usage: AnthropicUsage {
input_tokens: self.input_token_count, input_tokens: self.input_token_count,
output_tokens: self.output_token_count, output_tokens: self.output_token_count,
cache_creation_input_tokens: None,
cache_read_input_tokens: self.cached_token_count,
}, },
}; };
events.push(make_sse_event("message_delta", &message_delta)); events.push(make_sse_event("message_delta", &message_delta));
...@@ -329,6 +340,10 @@ impl AnthropicStreamConverter { ...@@ -329,6 +340,10 @@ impl AnthropicStreamConverter {
if let Some(usage) = &chunk.usage { if let Some(usage) = &chunk.usage {
self.input_token_count = usage.prompt_tokens; self.input_token_count = usage.prompt_tokens;
self.output_token_count = usage.completion_tokens; self.output_token_count = usage.completion_tokens;
self.cached_token_count = usage
.prompt_tokens_details
.as_ref()
.and_then(|d| d.cached_tokens);
} }
for choice in &chunk.choices { for choice in &chunk.choices {
...@@ -369,6 +384,7 @@ impl AnthropicStreamConverter { ...@@ -369,6 +384,7 @@ impl AnthropicStreamConverter {
index: self.text_block_index, index: self.text_block_index,
content_block: AnthropicResponseContentBlock::Text { content_block: AnthropicResponseContentBlock::Text {
text: String::new(), text: String::new(),
citations: None,
}, },
}; };
events.push(make_tagged_event("content_block_start", &ev)); events.push(make_tagged_event("content_block_start", &ev));
...@@ -483,6 +499,8 @@ impl AnthropicStreamConverter { ...@@ -483,6 +499,8 @@ impl AnthropicStreamConverter {
usage: AnthropicUsage { usage: AnthropicUsage {
input_tokens: self.input_token_count, input_tokens: self.input_token_count,
output_tokens: self.output_token_count, output_tokens: self.output_token_count,
cache_creation_input_tokens: None,
cache_read_input_tokens: self.cached_token_count,
}, },
}; };
events.push(make_tagged_event("message_delta", &ev)); events.push(make_tagged_event("message_delta", &ev));
......
This diff is collapsed.
...@@ -92,6 +92,10 @@ impl NvExtProvider for NvCreateChatCompletionRequest { ...@@ -92,6 +92,10 @@ impl NvExtProvider for NvCreateChatCompletionRequest {
fn raw_prompt(&self) -> Option<String> { fn raw_prompt(&self) -> Option<String> {
None None
} }
/// Returns the cache control carried in this request's `nvext` extension.
///
/// Yields `None` when the request has no `nvext` or when `nvext` has no
/// `cache_control` set.
///
/// NOTE(review): this mirrors the default provided by `NvExtProvider`;
/// presumably it is kept explicit as the hook for a future top-level
/// `cache_control` field — confirm.
fn effective_cache_control(&self) -> Option<&crate::protocols::openai::nvext::CacheControl> {
    // Fully-qualified call so the trait accessor is the one being used.
    let ext = NvExtProvider::nvext(self)?;
    ext.cache_control.as_ref()
}
} }
/// Implements `AnnotationsProvider` for `NvCreateChatCompletionRequest`, /// Implements `AnnotationsProvider` for `NvCreateChatCompletionRequest`,
......
...@@ -49,6 +49,13 @@ pub fn apply_header_routing_overrides(nvext: Option<NvExt>, headers: &HeaderMap) ...@@ -49,6 +49,13 @@ pub fn apply_header_routing_overrides(nvext: Option<NvExt>, headers: &HeaderMap)
pub trait NvExtProvider { pub trait NvExtProvider {
fn nvext(&self) -> Option<&NvExt>; fn nvext(&self) -> Option<&NvExt>;
fn raw_prompt(&self) -> Option<String>; fn raw_prompt(&self) -> Option<String>;
/// Returns the cache control that applies to this request, if any.
///
/// The default looks only at `nvext.cache_control`: it yields `None` when
/// the request has no `nvext` or when that field is unset. Implementors
/// whose requests can carry cache control elsewhere (for example in a
/// top-level `cache_control` field) may override this method.
fn effective_cache_control(&self) -> Option<&CacheControl> {
    let ext = self.nvext()?;
    ext.cache_control.as_ref()
}
} }
/// Worker ID information for disaggregated serving /// Worker ID information for disaggregated serving
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment