Unverified Commit f7fe8005 authored by MatejKosec's avatar MatejKosec Committed by GitHub
Browse files

fix: populate input_tokens in Anthropic streaming + add model retrieval endpoint (#7234)


Signed-off-by: Matej Kosec <mkosec@nvidia.com>
parent 49087845
......@@ -7,8 +7,10 @@
//! chat completions, processed by the existing engine, and responses/streams
//! are converted back to Anthropic format.
use std::collections::HashSet;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use axum::{
Json, Router,
......@@ -20,7 +22,7 @@ use axum::{
IntoResponse, Response,
sse::{KeepAlive, Sse},
},
routing::post,
routing::{get, post},
};
use dynamo_runtime::config::{env_is_truthy, environment_names::llm as env_llm};
use dynamo_runtime::pipeline::{AsyncEngineContextProvider, Context};
......@@ -73,6 +75,27 @@ pub fn anthropic_messages_router(
(vec![doc, count_doc], router)
}
/// Creates the router for model listing and retrieval.
///
/// Two `GET` routes are registered: `path` itself (default `/v1/models`),
/// served by `list_models`, and `{path}/{*model_id}`, served by `get_model`.
///
/// Both handlers content-negotiate on the `anthropic-version` request header:
/// when it is present they answer in the Anthropic model format
/// (`display_name`, `created_at`, and optionally `max_input_tokens` /
/// `max_tokens`); otherwise they answer in the OpenAI format (`object`,
/// `created`, `owned_by`, and optionally `context_window` /
/// `max_output_tokens`). This keeps Anthropic-specific content negotiation
/// out of the OpenAI handler.
///
/// Returns the documentation entries for both routes together with the router.
pub fn anthropic_models_router(
    state: Arc<service_v2::State>,
    path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
    // Build the default lazily so nothing is allocated when a path is supplied.
    let models_path = path.unwrap_or_else(|| "/v1/models".to_string());
    // Wildcard capture: model IDs may themselves contain `/`.
    let retrieve_path = format!("{}/{{*model_id}}", models_path);

    let list_doc = RouteDoc::new(axum::http::Method::GET, &models_path);
    let retrieve_doc = RouteDoc::new(axum::http::Method::GET, &retrieve_path);

    let router = Router::new()
        .route(&models_path, get(list_models))
        .route(&retrieve_path, get(get_model))
        .with_state(state);

    (vec![list_doc, retrieve_doc], router)
}
// ---------------------------------------------------------------------------
// Error middleware
// ---------------------------------------------------------------------------
......@@ -214,6 +237,14 @@ async fn anthropic_messages(
.as_ref()
.is_some_and(|t| t.thinking_type == "disabled");
// Estimate input tokens before consuming the request via try_into().
// Only used in the streaming path to populate message_start.
let estimated_input_tokens = if streaming {
estimate_input_tokens(&orig_request)
} else {
0
};
// Convert Anthropic request -> UnifiedRequest -> Chat Completion request
let unified_request: UnifiedRequest = orig_request.try_into().map_err(|e: anyhow::Error| {
tracing::error!(
......@@ -334,8 +365,10 @@ async fn anthropic_messages(
use std::sync::atomic::{AtomicBool, Ordering};
let mut converter = match anthropic_ctx {
Some(ctx) => AnthropicStreamConverter::with_context(model_for_resp, ctx),
None => AnthropicStreamConverter::new(model_for_resp),
Some(ctx) => {
AnthropicStreamConverter::with_context(model_for_resp, estimated_input_tokens, ctx)
}
None => AnthropicStreamConverter::new(model_for_resp, estimated_input_tokens),
};
let start_events = converter.emit_start_events();
......@@ -465,6 +498,195 @@ async fn handler_count_tokens(
.into_response())
}
// ---------------------------------------------------------------------------
// Model listing / retrieval (content-negotiating)
// ---------------------------------------------------------------------------
/// Collects the context length of every model card, keyed by the card's
/// display name.
///
/// If two cards share a display name, the one seen last wins.
fn build_model_context_map(state: &service_v2::State) -> std::collections::HashMap<String, u32> {
    let mut context_by_name = std::collections::HashMap::new();
    for card in state.manager().get_model_cards().iter() {
        context_by_name.insert(card.display_name.clone(), card.context_length);
    }
    context_by_name
}
/// Read optional env var overrides for context window and max output tokens.
fn model_env_overrides() -> (Option<u64>, Option<u64>) {
let context_window = match std::env::var("DYN_CONTEXT_WINDOW") {
Ok(v) => match v.parse::<u64>() {
Ok(val) => Some(val),
Err(_) => {
tracing::warn!("Invalid DYN_CONTEXT_WINDOW value '{}', ignoring", v);
None
}
},
Err(_) => None,
};
let max_output_tokens = match std::env::var("DYN_MAX_OUTPUT_TOKENS") {
Ok(v) => match v.parse::<u64>() {
Ok(val) => Some(val),
Err(_) => {
tracing::warn!("Invalid DYN_MAX_OUTPUT_TOKENS value '{}', ignoring", v);
None
}
},
Err(_) => None,
};
(context_window, max_output_tokens)
}
/// Picks the context window to report for `model_name`.
///
/// An explicit `env_override` always wins. Without one, the value recorded for
/// the model in `card_map` is used (widened to `u64`). Returns `None` when
/// neither source has a value.
fn resolve_context_window(
    model_name: &str,
    card_map: &std::collections::HashMap<String, u32>,
    env_override: Option<u64>,
) -> Option<u64> {
    if let Some(forced) = env_override {
        return Some(forced);
    }
    card_map.get(model_name).copied().map(u64::from)
}
/// List all models. Returns Anthropic format when `anthropic-version` header
/// is present, otherwise OpenAI format.
///
/// Models are emitted in sorted order of their display name so that the
/// response, and in particular the Anthropic `first_id` / `last_id` fields,
/// is stable from one request to the next.
///
/// Optional per-model fields:
/// - Anthropic: `max_input_tokens` (context window) and `max_tokens`.
/// - OpenAI: `context_window` and `max_output_tokens`.
///
/// The context window comes from `resolve_context_window` (env override, then
/// the model card). The max-output value comes only from the env override.
/// The reported creation timestamp is the time of the request.
///
/// # Errors
///
/// Propagates the error returned by `check_ready`.
async fn list_models(
    State(state): State<Arc<service_v2::State>>,
    headers: HeaderMap,
) -> Result<Response, super::openai::ErrorResponse> {
    super::openai::check_ready(&state)?;

    // A system clock set before the Unix epoch degrades to 0 rather than
    // panicking inside a request handler.
    let created = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs());

    // The manager hands back an unordered set; sort for a deterministic listing.
    let mut models: Vec<String> = state
        .manager()
        .model_display_names()
        .into_iter()
        .collect();
    models.sort_unstable();

    let card_map = build_model_context_map(&state);
    let (cw_override, mot_override) = model_env_overrides();

    if headers.contains_key("anthropic-version") {
        let created_at = chrono::DateTime::from_timestamp(created as i64, 0)
            .unwrap_or_default()
            .format("%Y-%m-%dT%H:%M:%SZ")
            .to_string();

        let data: Vec<serde_json::Value> = models
            .iter()
            .map(|name| {
                let mut obj = serde_json::json!({
                    "id": name,
                    "display_name": name,
                    "type": "model",
                    "created_at": created_at,
                });
                if let Some(cw) = resolve_context_window(name, &card_map, cw_override) {
                    obj["max_input_tokens"] = serde_json::json!(cw);
                }
                if let Some(mot) = mot_override {
                    obj["max_tokens"] = serde_json::json!(mot);
                }
                obj
            })
            .collect();

        // `data` is built 1:1 from `models`, so the cursor IDs are simply the
        // first and last model names (or null when there are no models).
        let first_id = models.first().cloned();
        let last_id = models.last().cloned();

        return Ok(Json(serde_json::json!({
            "data": data,
            "has_more": false,
            "first_id": first_id,
            "last_id": last_id,
        }))
        .into_response());
    }

    // OpenAI format fallback
    let data: Vec<serde_json::Value> = models
        .iter()
        .map(|name| {
            let mut obj = serde_json::json!({
                "id": name,
                "object": "model",
                "created": created,
                "owned_by": "nvidia",
            });
            if let Some(cw) = resolve_context_window(name, &card_map, cw_override) {
                obj["context_window"] = serde_json::json!(cw);
            }
            if let Some(mot) = mot_override {
                obj["max_output_tokens"] = serde_json::json!(mot);
            }
            obj
        })
        .collect();

    Ok(Json(serde_json::json!({
        "object": "list",
        "data": data,
    }))
    .into_response())
}
/// Retrieve a single model by ID. Returns Anthropic format when
/// `anthropic-version` header is present, otherwise OpenAI format.
///
/// The model ID may contain slashes (e.g. `Qwen/Qwen3.5-35B-A3B-FP8`),
/// which is why this uses a wildcard `/{*model_id}` path parameter.
///
/// Optional fields are added only when a value is available: the context
/// window (`max_input_tokens` for Anthropic, `context_window` for OpenAI) and
/// the max-output override (`max_tokens` for Anthropic, `max_output_tokens`
/// for OpenAI). The reported creation timestamp is the time of the request.
///
/// # Errors
///
/// Propagates the error from `check_ready`, and returns
/// `ErrorMessage::model_not_found()` when no model with this display name is
/// registered.
async fn get_model(
    State(state): State<Arc<service_v2::State>>,
    headers: HeaderMap,
    axum::extract::Path(model_id): axum::extract::Path<String>,
) -> Result<Response, super::openai::ErrorResponse> {
    super::openai::check_ready(&state)?;

    // Defensively drop one leading slash in case the wildcard capture
    // includes it; a capture without it is used unchanged.
    let model_id = model_id.strip_prefix('/').unwrap_or(&model_id);

    let models: HashSet<String> = state.manager().model_display_names();
    if !models.contains(model_id) {
        return Err(super::openai::ErrorMessage::model_not_found());
    }

    // A system clock set before the Unix epoch degrades to 0 rather than
    // panicking inside a request handler.
    let created = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs());

    let card_map = build_model_context_map(&state);
    let (cw_override, mot_override) = model_env_overrides();
    let context_window = resolve_context_window(model_id, &card_map, cw_override);

    // Pick the base object and the format-specific names of the two optional
    // fields once, then attach the optional fields in one place.
    let (mut obj, context_key, max_output_key) = if headers.contains_key("anthropic-version") {
        let created_at = chrono::DateTime::from_timestamp(created as i64, 0)
            .unwrap_or_default()
            .format("%Y-%m-%dT%H:%M:%SZ")
            .to_string();
        (
            serde_json::json!({
                "id": model_id,
                "display_name": model_id,
                "type": "model",
                "created_at": created_at,
            }),
            "max_input_tokens",
            "max_tokens",
        )
    } else {
        (
            serde_json::json!({
                "id": model_id,
                "object": "model",
                "created": created,
                "owned_by": "nvidia",
            }),
            "context_window",
            "max_output_tokens",
        )
    };

    if let Some(cw) = context_window {
        obj[context_key] = serde_json::json!(cw);
    }
    if let Some(mot) = mot_override {
        obj[max_output_key] = serde_json::json!(mot);
    }

    Ok(Json(obj).into_response())
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
......@@ -485,6 +707,23 @@ fn strip_billing_preamble(system: &mut Option<SystemContent>) {
}
}
/// Produces an up-front input token estimate for an Anthropic request.
///
/// The request's `model`, `messages`, `system` and `tools` are cloned into a
/// temporary `AnthropicCountTokensRequest`, and the estimate is whatever its
/// `estimate_tokens()` returns, so both code paths share one heuristic.
///
/// The caller uses the result to fill `input_tokens` in the streaming
/// `message_start` event, because the engine only reports prompt token counts
/// on the final chunk.
fn estimate_input_tokens(req: &AnthropicCreateMessageRequest) -> u32 {
    AnthropicCountTokensRequest {
        model: req.model.clone(),
        messages: req.messages.clone(),
        system: req.system.clone(),
        tools: req.tools.clone(),
    }
    .estimate_tokens()
}
/// Build an Anthropic-formatted error response.
/// Maps HTTP status codes to Anthropic error types following the Anthropic API spec.
fn anthropic_error(status: StatusCode, error_type: &str, message: &str) -> Response {
......
......@@ -1812,7 +1812,7 @@ pub fn validate_response_unsupported_fields(
// todo - abstract this to the top level lib.rs to be reused
// todo - move the service_observer to its own state/arc
fn check_ready(_state: &Arc<service_v2::State>) -> Result<(), ErrorResponse> {
pub(crate) fn check_ready(_state: &Arc<service_v2::State>) -> Result<(), ErrorResponse> {
// if state.service_observer.stage() != ServiceStage::Ready {
// return Err(ErrorMessage::service_unavailable());
// }
......@@ -1841,15 +1841,34 @@ async fn list_models_openai(
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
// Build context_length lookup from model deployment cards
let cards = state.manager().get_model_cards();
let card_map: HashMap<String, u32> = cards
.iter()
.map(|c| (c.display_name.clone(), c.context_length))
.collect();
// Env var overrides (take precedence over MDC values)
let cw_override: Option<u64> = std::env::var("DYN_CONTEXT_WINDOW")
.ok()
.and_then(|v| v.parse().ok());
let mot_override: Option<u64> = std::env::var("DYN_MAX_OUTPUT_TOKENS")
.ok()
.and_then(|v| v.parse().ok());
let mut data = Vec::new();
let models: HashSet<String> = state.manager().model_display_names();
for model_name in models {
let context_window = cw_override.or_else(|| card_map.get(&model_name).map(|&cl| cl as u64));
data.push(ModelListing {
id: model_name.clone(),
object: "model", // Per OpenAI spec, this should be "model"
object: "model",
created,
owned_by: "nvidia".to_string(),
context_window,
max_output_tokens: mot_override,
});
}
......@@ -1872,6 +1891,10 @@ struct ModelListing {
object: &'static str, // always "model" per OpenAI spec
created: u64, // Seconds since epoch
owned_by: String,
#[serde(skip_serializing_if = "Option::is_none")]
context_window: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
max_output_tokens: Option<u64>,
}
/// Create an Axum [`Router`] for the OpenAI API Completions endpoint
......@@ -1930,13 +1953,62 @@ pub fn list_models_router(
) -> (Vec<RouteDoc>, Router) {
// Standard OpenAI compatible list models endpoint
let openai_path = path.unwrap_or("/v1/models".to_string());
let retrieve_path = format!("{}/{{*model_id}}", openai_path);
let doc_for_openai = RouteDoc::new(axum::http::Method::GET, &openai_path);
let doc_for_retrieve = RouteDoc::new(axum::http::Method::GET, &retrieve_path);
let router = Router::new()
.route(&openai_path, get(list_models_openai))
.route(&retrieve_path, get(get_model_openai))
.with_state(state);
(vec![doc_for_openai], router)
(vec![doc_for_openai, doc_for_retrieve], router)
}
/// Retrieve a single model by ID (OpenAI format).
///
/// Per the OpenAI API spec: `GET /v1/models/{model}` returns a model object.
/// Uses wildcard path to support model IDs with slashes (e.g. `Qwen/Qwen3.5-35B-A3B-FP8`).
async fn get_model_openai(
State(state): State<Arc<service_v2::State>>,
axum::extract::Path(model_id): axum::extract::Path<String>,
) -> Result<Response, ErrorResponse> {
check_ready(&state)?;
let model_id = model_id.strip_prefix('/').unwrap_or(&model_id);
let models: HashSet<String> = state.manager().model_display_names();
if !models.contains(model_id) {
return Err(ErrorMessage::model_not_found());
}
let created = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
let cards = state.manager().get_model_cards();
let context_length = cards
.iter()
.find(|c| c.display_name == model_id)
.map(|c| c.context_length as u64);
let context_window: Option<u64> = std::env::var("DYN_CONTEXT_WINDOW")
.ok()
.and_then(|v| v.parse().ok())
.or(context_length);
let max_output_tokens: Option<u64> = std::env::var("DYN_MAX_OUTPUT_TOKENS")
.ok()
.and_then(|v| v.parse().ok());
Ok(Json(ModelListing {
id: model_id.to_string(),
object: "model",
created,
owned_by: "nvidia".to_string(),
context_window,
max_output_tokens,
})
.into_response())
}
/// Create an Axum [`Router`] for the OpenAI API Responses endpoint
......
......@@ -524,7 +524,14 @@ impl HttpServiceConfigBuilder {
var(HTTP_SVC_METRICS_PATH_ENV).ok(),
config.drt_metrics,
),
super::openai::list_models_router(state.clone(), var(HTTP_SVC_MODELS_PATH_ENV).ok()),
if env_is_truthy(env_llm::DYN_ENABLE_ANTHROPIC_API) {
super::anthropic::anthropic_models_router(
state.clone(),
var(HTTP_SVC_MODELS_PATH_ENV).ok(),
)
} else {
super::openai::list_models_router(state.clone(), var(HTTP_SVC_MODELS_PATH_ENV).ok())
},
super::health::health_check_router(state.clone(), var(HTTP_SVC_HEALTH_PATH_ENV).ok()),
super::health::live_check_router(state.clone(), var(HTTP_SVC_LIVE_PATH_ENV).ok()),
super::busy_threshold::busy_threshold_router(state.clone(), None),
......
......@@ -59,7 +59,7 @@ struct ToolCallState {
}
impl AnthropicStreamConverter {
pub fn new(model: String) -> Self {
pub fn new(model: String, estimated_input_tokens: u32) -> Self {
Self {
model,
message_id: format!("msg_{}", Uuid::new_v4().simple()),
......@@ -70,7 +70,7 @@ impl AnthropicStreamConverter {
text_block_started: false,
text_block_closed: false,
text_block_index: 0,
input_token_count: 0,
input_token_count: estimated_input_tokens,
output_token_count: 0,
cached_token_count: None,
tool_call_states: Vec::new(),
......@@ -83,8 +83,12 @@ impl AnthropicStreamConverter {
/// Create a converter seeded with the original Anthropic request context.
/// This allows the response stream to carry forward metadata that was lost
/// during the Anthropic-to-OpenAI request conversion.
pub fn with_context(model: String, context: AnthropicContext) -> Self {
let mut converter = Self::new(model);
pub fn with_context(
model: String,
estimated_input_tokens: u32,
context: AnthropicContext,
) -> Self {
let mut converter = Self::new(model, estimated_input_tokens);
converter.api_context = Some(context);
converter
}
......@@ -102,7 +106,7 @@ impl AnthropicStreamConverter {
stop_reason: None,
stop_sequence: None,
usage: AnthropicUsage {
input_tokens: 0,
input_tokens: self.input_token_count,
output_tokens: 0,
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
......@@ -120,9 +124,11 @@ impl AnthropicStreamConverter {
) -> Vec<Result<Event, anyhow::Error>> {
let mut events = Vec::new();
// Capture real token usage from engine when available (typically on the final chunk).
// Capture token usage from engine when available (typically on the final chunk).
// Only update output_token_count — input_token_count is set once from the
// estimate in new() and must stay consistent between message_start and
// message_delta to avoid Claude Code's token display jumping.
if let Some(usage) = &chunk.inner.usage {
self.input_token_count = usage.prompt_tokens;
self.output_token_count = usage.completion_tokens;
self.cached_token_count = usage
.prompt_tokens_details
......@@ -463,7 +469,6 @@ impl AnthropicStreamConverter {
let mut events = Vec::new();
if let Some(usage) = &chunk.inner.usage {
self.input_token_count = usage.prompt_tokens;
self.output_token_count = usage.completion_tokens;
self.cached_token_count = usage
.prompt_tokens_details
......@@ -820,7 +825,7 @@ mod tests {
/// events and fail to execute tool calls ("Error editing file").
#[test]
fn test_text_block_stops_before_tool_block_starts() {
let mut conv = AnthropicStreamConverter::new("test-model".into());
let mut conv = AnthropicStreamConverter::new("test-model".into(), 0);
// Stream some text
let text_events = conv.process_chunk_tagged(&text_chunk("I'll edit the file."));
......@@ -881,7 +886,7 @@ mod tests {
/// Tool-only response (no preceding text): no spurious stop events.
#[test]
fn test_tool_only_response_no_text_block() {
let mut conv = AnthropicStreamConverter::new("test-model".into());
let mut conv = AnthropicStreamConverter::new("test-model".into(), 0);
let tool_events = conv.process_chunk_tagged(&tool_call_chunk(
0,
......@@ -910,7 +915,7 @@ mod tests {
/// Text-only response: stop emitted in end events (no early close).
#[test]
fn test_text_only_response_stop_in_end_events() {
let mut conv = AnthropicStreamConverter::new("test-model".into());
let mut conv = AnthropicStreamConverter::new("test-model".into(), 0);
conv.process_chunk_tagged(&text_chunk("Hello world"));
......@@ -960,7 +965,7 @@ mod tests {
/// block is properly closed before the next one starts.
#[test]
fn test_thinking_text_then_tool_call() {
let mut conv = AnthropicStreamConverter::new("test-model".into());
let mut conv = AnthropicStreamConverter::new("test-model".into(), 0);
// 1. Reasoning tokens → thinking block starts
let ev = conv.process_chunk_tagged(&reasoning_chunk("Let me think..."));
......@@ -1027,7 +1032,7 @@ mod tests {
/// Thinking-only response (no text/tool follows): thinking block closed in end events.
#[test]
fn test_thinking_only_closed_in_end_events() {
let mut conv = AnthropicStreamConverter::new("test-model".into());
let mut conv = AnthropicStreamConverter::new("test-model".into(), 0);
conv.process_chunk_tagged(&reasoning_chunk("Deep thought..."));
let ev = conv.emit_end_events_tagged();
......@@ -1045,7 +1050,7 @@ mod tests {
/// Multiple tool calls: each gets inline content_block_stop.
#[test]
fn test_multiple_tool_calls_each_stopped_inline() {
let mut conv = AnthropicStreamConverter::new("test-model".into());
let mut conv = AnthropicStreamConverter::new("test-model".into(), 0);
let events1 = conv.process_chunk_tagged(&tool_call_chunk(
0,
......@@ -1098,7 +1103,7 @@ mod tests {
service_tier: Some("priority".to_string()),
..Default::default()
};
let mut conv = AnthropicStreamConverter::with_context("test-model".into(), ctx);
let mut conv = AnthropicStreamConverter::with_context("test-model".into(), 0, ctx);
assert!(conv.api_context.is_some());
assert_eq!(
conv.api_context.as_ref().unwrap().service_tier.as_deref(),
......
......@@ -29,6 +29,9 @@ use crate::protocols::openai::chat_completions::{
};
use crate::protocols::openai::common_ext::CommonExt;
// ---------------------------------------------------------------------------
// Conversion: AnthropicCreateMessageRequest -> NvCreateChatCompletionRequest
// ---------------------------------------------------------------------------
impl TryFrom<AnthropicCreateMessageRequest> for NvCreateChatCompletionRequest {
type Error = anyhow::Error;
......@@ -541,6 +544,7 @@ pub fn chat_completion_to_anthropic_response(
usage,
}
}
#[cfg(test)]
mod tests {
use super::*;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment