Unverified Commit be9d6b2b authored by Vladislav Nosivskoy's avatar Vladislav Nosivskoy Committed by GitHub
Browse files

feat: support prompt_tokens_details in usage (#4239)


Signed-off-by: default avatarVladislav Nosivskoy <vladnosiv@gmail.com>
parent 0f4d7634
...@@ -120,6 +120,7 @@ class StandaloneRouterHandler: ...@@ -120,6 +120,7 @@ class StandaloneRouterHandler:
"index": worker_output.get("index"), "index": worker_output.get("index"),
"disaggregated_params": worker_output.get("disaggregated_params"), "disaggregated_params": worker_output.get("disaggregated_params"),
"extra_args": worker_output.get("extra_args"), "extra_args": worker_output.get("extra_args"),
"completion_usage": worker_output.get("completion_usage"),
} }
yield llm_engine_output yield llm_engine_output
......
...@@ -229,6 +229,19 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -229,6 +229,19 @@ class DecodeWorkerHandler(BaseWorkerHandler):
next_total_toks = len(output_ids) next_total_toks = len(output_ids)
out["token_ids"] = output_ids[num_output_tokens_so_far:] out["token_ids"] = output_ids[num_output_tokens_so_far:]
num_output_tokens_so_far = next_total_toks num_output_tokens_so_far = next_total_toks
if finish_reason:
input_tokens = res["meta_info"]["prompt_tokens"]
completion_tokens = res["meta_info"]["completion_tokens"]
cached_tokens = res["meta_info"]["cached_tokens"]
prefill_prompt_tokens_details = None
if cached_tokens is not None and cached_tokens > 0:
prefill_prompt_tokens_details = {"cached_tokens": cached_tokens}
out["completion_usage"] = {
"prompt_tokens": input_tokens,
"completion_tokens": completion_tokens,
"total_tokens": input_tokens + completion_tokens,
"prompt_tokens_details": prefill_prompt_tokens_details,
}
if not context.is_stopped(): if not context.is_stopped():
yield out yield out
......
...@@ -242,6 +242,10 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -242,6 +242,10 @@ async def init(runtime: DistributedRuntime, config: Config):
default_sampling_params = SamplingParams() default_sampling_params = SamplingParams()
default_sampling_params._setup(tokenizer) default_sampling_params._setup(tokenizer)
default_sampling_params.stop = None default_sampling_params.stop = None
# Enable perf metrics so prompt_tokens_details can be returned
if hasattr(default_sampling_params, "return_perf_metrics"):
default_sampling_params.return_perf_metrics = True
model_input = ModelInput.Tokens model_input = ModelInput.Tokens
# Set model type based on disaggregation mode for unified frontend support # Set model type based on disaggregation mode for unified frontend support
...@@ -356,6 +360,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -356,6 +360,7 @@ async def init(runtime: DistributedRuntime, config: Config):
connector=connector, connector=connector,
runtime=runtime, # Pass runtime for graceful shutdown runtime=runtime, # Pass runtime for graceful shutdown
metrics_collector=metrics_collector, metrics_collector=metrics_collector,
kv_block_size=config.kv_block_size,
) )
# Register the model with runtime config # Register the model with runtime config
......
...@@ -72,6 +72,7 @@ class RequestHandlerConfig: ...@@ -72,6 +72,7 @@ class RequestHandlerConfig:
DistributedRuntime DistributedRuntime
] = None # DistributedRuntime reference for graceful shutdown ] = None # DistributedRuntime reference for graceful shutdown
metrics_collector: Optional[Any] = None # TensorRT-LLM MetricsCollector metrics_collector: Optional[Any] = None # TensorRT-LLM MetricsCollector
kv_block_size: int = 32
class HandlerBase: class HandlerBase:
...@@ -92,6 +93,7 @@ class HandlerBase: ...@@ -92,6 +93,7 @@ class HandlerBase:
self.connector = config.connector self.connector = config.connector
# Store runtime reference for graceful shutdown # Store runtime reference for graceful shutdown
self.runtime = config.runtime self.runtime = config.runtime
self.kv_block_size: int = config.kv_block_size
def check_error(self, result: dict): def check_error(self, result: dict):
""" """
...@@ -208,11 +210,13 @@ class HandlerBase: ...@@ -208,11 +210,13 @@ class HandlerBase:
request["stop_conditions"]["max_tokens"] = 1 request["stop_conditions"]["max_tokens"] = 1
disaggregated_params = LlmDisaggregatedParams(request_type="context_only") disaggregated_params = LlmDisaggregatedParams(request_type="context_only")
if "disaggregated_params" in request: if "prefill_result" in request:
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.PREFILL:
raise ValueError("Cannot provide disaggregated_params in prefill mode") raise ValueError("Cannot provide disaggregated_params in prefill mode")
disaggregated_params = DisaggregatedParamsCodec.decode( disaggregated_params = DisaggregatedParamsCodec.decode(
DisaggregatedParams(**request["disaggregated_params"]) DisaggregatedParams(
**request["prefill_result"].get("disaggregated_params")
)
) )
disaggregated_params.request_type = "generation_only" disaggregated_params.request_type = "generation_only"
...@@ -258,6 +262,11 @@ class HandlerBase: ...@@ -258,6 +262,11 @@ class HandlerBase:
adapters = create_trtllm_adapters(processors) adapters = create_trtllm_adapters(processors)
sampling_params.logits_processor = adapters sampling_params.logits_processor = adapters
prefill_result = request.get("prefill_result")
prefill_prompt_tokens_details = (
prefill_result.get("prompt_tokens_details") if prefill_result else None
)
try: try:
# NEW: Updated engine call to include multimodal data # NEW: Updated engine call to include multimodal data
generation_result = self.engine.llm.generate_async( generation_result = self.engine.llm.generate_async(
...@@ -298,6 +307,34 @@ class HandlerBase: ...@@ -298,6 +307,34 @@ class HandlerBase:
DisaggregatedParamsCodec.encode(output.disaggregated_params) DisaggregatedParamsCodec.encode(output.disaggregated_params)
) )
if out.get("finish_reason"):
num_input_tokens = len(request.get("token_ids", []))
prompt_tokens_details = None
if prefill_prompt_tokens_details:
prompt_tokens_details = prefill_prompt_tokens_details
else:
if output.request_perf_metrics is not None:
kv_cache_metrics = (
output.request_perf_metrics.kv_cache_metrics
)
cached_tokens = min(
num_input_tokens,
kv_cache_metrics.num_reused_blocks
* self.kv_block_size,
)
if cached_tokens > 0:
prompt_tokens_details = {
"cached_tokens": int(cached_tokens),
}
out["completion_usage"] = {
"prompt_tokens": int(num_input_tokens),
"completion_tokens": int(next_total_toks),
"total_tokens": int(num_input_tokens + next_total_toks),
"prompt_tokens_details": prompt_tokens_details,
}
if res.finished and not out.get("finish_reason"): if res.finished and not out.get("finish_reason"):
out["finish_reason"] = "unknown" out["finish_reason"] = "unknown"
logging.warning( logging.warning(
......
...@@ -10,6 +10,7 @@ from contextlib import asynccontextmanager ...@@ -10,6 +10,7 @@ from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Final from typing import Any, AsyncGenerator, Dict, Final
from vllm.inputs import TokensPrompt from vllm.inputs import TokensPrompt
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.engine.exceptions import EngineDeadError
...@@ -174,6 +175,28 @@ class BaseWorkerHandler(ABC): ...@@ -174,6 +175,28 @@ class BaseWorkerHandler(ABC):
return vllm_mm_data if vllm_mm_data else None return vllm_mm_data if vllm_mm_data else None
@staticmethod
def _build_completion_usage(request_output: RequestOutput) -> Dict[str, Any]:
return {
"prompt_tokens": (
len(request_output.prompt_token_ids)
if request_output.prompt_token_ids
else None
),
"completion_tokens": len(request_output.outputs[0].token_ids),
"total_tokens": (
len(request_output.prompt_token_ids)
+ len(request_output.outputs[0].token_ids)
if request_output.prompt_token_ids
else None
),
"prompt_tokens_details": (
{"cached_tokens": request_output.num_cached_tokens}
if request_output.num_cached_tokens
else None
),
}
async def generate_tokens( async def generate_tokens(
self, prompt, sampling_params, request_id, data_parallel_rank=None self, prompt, sampling_params, request_id, data_parallel_rank=None
): ):
...@@ -199,6 +222,11 @@ class BaseWorkerHandler(ABC): ...@@ -199,6 +222,11 @@ class BaseWorkerHandler(ABC):
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]} out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason: if output.finish_reason:
out["finish_reason"] = output.finish_reason out["finish_reason"] = output.finish_reason
out[
"completion_usage"
] = BaseWorkerHandler._build_completion_usage(
request_output=res
)
if output.stop_reason: if output.stop_reason:
out["stop_reason"] = output.stop_reason out["stop_reason"] = output.stop_reason
yield out yield out
...@@ -241,18 +269,24 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -241,18 +269,24 @@ class DecodeWorkerHandler(BaseWorkerHandler):
# Build sampling params from request # Build sampling params from request
sampling_params = build_sampling_params(request, self.default_sampling_params) sampling_params = build_sampling_params(request, self.default_sampling_params)
# Extract disaggregated_params from request (set by prefill router in Rust frontend) prefill_result = request.get("prefill_result")
disaggregated_params = request.get("disaggregated_params") if prefill_result and isinstance(prefill_result, dict):
if disaggregated_params: kv_params = prefill_result.get("disaggregated_params", {}).get(
# Prefill was performed - use the disaggregated params
if sampling_params.extra_args is None:
sampling_params.extra_args = {}
sampling_params.extra_args["kv_transfer_params"] = disaggregated_params.get(
"kv_transfer_params" "kv_transfer_params"
) )
else:
kv_params = None
if kv_params is not None:
if sampling_params.extra_args is None:
sampling_params.extra_args = {}
sampling_params.extra_args["kv_transfer_params"] = kv_params
logger.debug( logger.debug(
f"Using disaggregated params from prefill for request {request_id}" f"Using disaggregated params from prefill for request {request_id}"
) )
prefill_prompt_tokens_details = (
prefill_result.get("prompt_tokens_details") if prefill_result else None
)
dp_rank = request.get("dp_rank", None) dp_rank = request.get("dp_rank", None)
...@@ -261,6 +295,10 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -261,6 +295,10 @@ class DecodeWorkerHandler(BaseWorkerHandler):
async for tok in self.generate_tokens( async for tok in self.generate_tokens(
prompt, sampling_params, request_id, data_parallel_rank=dp_rank prompt, sampling_params, request_id, data_parallel_rank=dp_rank
): ):
if prefill_result is not None and "completion_usage" in tok:
tok["completion_usage"][
"prompt_tokens_details"
] = prefill_prompt_tokens_details
yield tok yield tok
except EngineDeadError as e: except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}") logger.error(f"vLLM EngineDeadError: {e}")
...@@ -325,6 +363,9 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -325,6 +363,9 @@ class PrefillWorkerHandler(BaseWorkerHandler):
if res.kv_transfer_params if res.kv_transfer_params
else None else None
), ),
"completion_usage": BaseWorkerHandler._build_completion_usage(
request_output=res
),
} }
yield output yield output
......
...@@ -242,6 +242,7 @@ impl ...@@ -242,6 +242,7 @@ impl
finish_reason: data.finish_reason, finish_reason: data.finish_reason,
//mdcsum: mdcsum.clone(), //mdcsum: mdcsum.clone(),
index: data.index, index: data.index,
completion_usage: data.completion_usage,
}) })
}) })
}); });
......
...@@ -21,6 +21,7 @@ use crate::{ ...@@ -21,6 +21,7 @@ use crate::{
discovery::ModelManager, discovery::ModelManager,
kv_router::{KvPushRouter, KvRouterConfig, RouterConfigOverride}, kv_router::{KvPushRouter, KvRouterConfig, RouterConfigOverride},
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest}, protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
protocols::common::preprocessor::PrefillResult,
}; };
/// Errors that can occur during prefill routing /// Errors that can occur during prefill routing
...@@ -175,11 +176,11 @@ impl PrefillRouter { ...@@ -175,11 +176,11 @@ impl PrefillRouter {
Ok(()) Ok(())
} }
/// Call the prefill router and extract disaggregated_params /// Call the prefill router and extract structured prefill result
async fn call_prefill( async fn call_prefill(
&self, &self,
request: SingleIn<PreprocessedRequest>, request: SingleIn<PreprocessedRequest>,
) -> Result<serde_json::Value, PrefillError> { ) -> Result<PrefillResult, PrefillError> {
// Get the prefill router, error if not activated // Get the prefill router, error if not activated
let Some(prefill_router) = self.prefill_router.get() else { let Some(prefill_router) = self.prefill_router.get() else {
return Err(PrefillError::NotActivated); return Err(PrefillError::NotActivated);
...@@ -203,7 +204,22 @@ impl PrefillRouter { ...@@ -203,7 +204,22 @@ impl PrefillRouter {
)); ));
}; };
while prefill_response.next().await.is_some() {} let mut prompt_tokens_details = first_output
.data
.as_ref()
.and_then(|o| o.completion_usage.as_ref())
.and_then(|u| u.prompt_tokens_details.clone());
while let Some(next) = prefill_response.next().await {
if let Some(o) = next.data.as_ref()
&& prompt_tokens_details.is_none()
{
prompt_tokens_details = o
.completion_usage
.as_ref()
.and_then(|u| u.prompt_tokens_details.clone());
}
}
if let Some(err) = first_output.err() { if let Some(err) = first_output.err() {
return Err(PrefillError::PrefillError(format!( return Err(PrefillError::PrefillError(format!(
...@@ -223,7 +239,10 @@ impl PrefillRouter { ...@@ -223,7 +239,10 @@ impl PrefillRouter {
)); ));
}; };
Ok(disaggregated_params) Ok(PrefillResult {
disaggregated_params,
prompt_tokens_details,
})
} }
} }
...@@ -267,12 +286,12 @@ impl ...@@ -267,12 +286,12 @@ impl
// Attempt prefill and handle results // Attempt prefill and handle results
match self.call_prefill(prefill_request).await { match self.call_prefill(prefill_request).await {
Ok(disaggregated_params) => { Ok(prefill_result) => {
tracing::debug!("Prefill succeeded, using disaggregated params for decode"); tracing::debug!("Prefill succeeded, using disaggregated params for decode");
// Update request with disaggregated_params and router config
let mut decode_req = req; let mut decode_req = req;
decode_req.disaggregated_params = Some(disaggregated_params); // Update request with prefill result
decode_req.prefill_result = Some(prefill_result.clone());
// Restore original max_tokens for decode // Restore original max_tokens for decode
decode_req.stop_conditions.max_tokens = original_max_tokens; decode_req.stop_conditions.max_tokens = original_max_tokens;
......
...@@ -219,6 +219,7 @@ mod tests { ...@@ -219,6 +219,7 @@ mod tests {
index: None, index: None,
disaggregated_params: None, disaggregated_params: None,
extra_args: None, extra_args: None,
completion_usage: None,
}) })
} }
......
...@@ -308,6 +308,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -308,6 +308,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
None None
}, },
extra_args: None, extra_args: None,
completion_usage: None,
}; };
if signal.completed && token_count < max_output_tokens { if signal.completed && token_count < max_output_tokens {
......
...@@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize}; ...@@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize};
pub use super::FinishReason; pub use super::FinishReason;
pub use super::preprocessor::PreprocessedRequest; pub use super::preprocessor::PreprocessedRequest;
use crate::protocols::TokenIdType; use crate::protocols::TokenIdType;
use dynamo_async_openai::types::CompletionUsage;
use dynamo_runtime::protocols::maybe_error::MaybeError; use dynamo_runtime::protocols::maybe_error::MaybeError;
pub type TokenType = Option<String>; pub type TokenType = Option<String>;
...@@ -48,6 +49,10 @@ pub struct BackendOutput { ...@@ -48,6 +49,10 @@ pub struct BackendOutput {
// Index field for batch requests to match OpenAI format // Index field for batch requests to match OpenAI format
pub index: Option<u32>, pub index: Option<u32>,
// Token usage information
#[serde(default, skip_serializing_if = "Option::is_none")]
pub completion_usage: Option<CompletionUsage>,
} }
/// The LLM engine and backnd with manage it's own state, specifically translating how a /// The LLM engine and backnd with manage it's own state, specifically translating how a
...@@ -92,6 +97,10 @@ pub struct LLMEngineOutput { ...@@ -92,6 +97,10 @@ pub struct LLMEngineOutput {
/// Additional arguments for extensibility /// Additional arguments for extensibility
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub extra_args: Option<serde_json::Value>, pub extra_args: Option<serde_json::Value>,
// Token usage information
#[serde(default, skip_serializing_if = "Option::is_none")]
pub completion_usage: Option<CompletionUsage>,
} }
impl LLMEngineOutput { impl LLMEngineOutput {
...@@ -107,6 +116,7 @@ impl LLMEngineOutput { ...@@ -107,6 +116,7 @@ impl LLMEngineOutput {
index: None, index: None,
disaggregated_params: None, disaggregated_params: None,
extra_args: None, extra_args: None,
completion_usage: None,
} }
} }
...@@ -122,6 +132,7 @@ impl LLMEngineOutput { ...@@ -122,6 +132,7 @@ impl LLMEngineOutput {
index: None, index: None,
disaggregated_params: None, disaggregated_params: None,
extra_args: None, extra_args: None,
completion_usage: None,
} }
} }
...@@ -137,6 +148,7 @@ impl LLMEngineOutput { ...@@ -137,6 +148,7 @@ impl LLMEngineOutput {
index: None, index: None,
disaggregated_params: None, disaggregated_params: None,
extra_args: None, extra_args: None,
completion_usage: None,
} }
} }
...@@ -152,6 +164,7 @@ impl LLMEngineOutput { ...@@ -152,6 +164,7 @@ impl LLMEngineOutput {
index: None, index: None,
disaggregated_params: None, disaggregated_params: None,
extra_args: None, extra_args: None,
completion_usage: None,
} }
} }
} }
......
...@@ -8,6 +8,15 @@ use super::{OutputOptions, SamplingOptions, StopConditions}; ...@@ -8,6 +8,15 @@ use super::{OutputOptions, SamplingOptions, StopConditions};
use crate::kv_router::RouterConfigOverride; use crate::kv_router::RouterConfigOverride;
use crate::protocols::TokenIdType; use crate::protocols::TokenIdType;
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PrefillResult {
/// Disaggregated execution parameters
pub disaggregated_params: serde_json::Value,
/// Prompt token details produced during prefill
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prompt_tokens_details: Option<dynamo_async_openai::types::PromptTokensDetails>,
}
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub enum MultimodalData { pub enum MultimodalData {
Url(url::Url), Url(url::Url),
...@@ -69,10 +78,10 @@ pub struct PreprocessedRequest { ...@@ -69,10 +78,10 @@ pub struct PreprocessedRequest {
#[builder(default)] #[builder(default)]
pub router_config_override: Option<RouterConfigOverride>, pub router_config_override: Option<RouterConfigOverride>,
/// Disaggregated execution parameters (for prefill/decode separation) /// Structured prefill result
#[builder(default)] #[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub disaggregated_params: Option<serde_json::Value>, pub prefill_result: Option<PrefillResult>,
/// Data parallel rank for the request (used with data parallelism) /// Data parallel rank for the request (used with data parallelism)
#[builder(default)] #[builder(default)]
......
...@@ -316,6 +316,16 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -316,6 +316,16 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
.expect("token_ids length exceeds u32::MAX"); .expect("token_ids length exceeds u32::MAX");
self.usage.completion_tokens += token_length; self.usage.completion_tokens += token_length;
// If backend provides completion_usage with prompt token details,
// propagate the entire details struct to usage tracking
if let Some(prompt_details) = delta
.completion_usage
.as_ref()
.and_then(|usage| usage.prompt_tokens_details.as_ref())
{
self.usage.prompt_tokens_details = Some(prompt_details.clone());
}
} }
let logprobs = self.create_logprobs( let logprobs = self.create_logprobs(
......
...@@ -238,6 +238,16 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for ...@@ -238,6 +238,16 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
.expect("token_ids length exceeds u32::MAX"); .expect("token_ids length exceeds u32::MAX");
self.usage.completion_tokens += token_length; self.usage.completion_tokens += token_length;
// If backend provides completion_usage with prompt token details,
// propagate the entire details struct to usage tracking
if let Some(prompt_details) = delta
.completion_usage
.as_ref()
.and_then(|usage| usage.prompt_tokens_details.as_ref())
{
self.usage.prompt_tokens_details = Some(prompt_details.clone());
}
} }
let logprobs = self.create_logprobs( let logprobs = self.create_logprobs(
......
...@@ -7,12 +7,17 @@ use dynamo_async_openai::types::{ ...@@ -7,12 +7,17 @@ use dynamo_async_openai::types::{
ChatCompletionRequestUserMessageContent, ChatCompletionStreamOptions, ChatCompletionRequestUserMessageContent, ChatCompletionStreamOptions,
CreateChatCompletionRequest, CreateChatCompletionRequest,
}; };
use dynamo_async_openai::types::{
CompletionUsage as AoaiCompletionUsage, CreateCompletionRequestArgs, Prompt,
PromptTokensDetails,
};
use dynamo_llm::preprocessor::OpenAIPreprocessor; use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_llm::protocols::common::llm_backend::{BackendOutput, FinishReason}; use dynamo_llm::protocols::common::llm_backend::{BackendOutput, FinishReason};
use dynamo_llm::protocols::openai::ParsingOptions; use dynamo_llm::protocols::openai::ParsingOptions;
use dynamo_llm::protocols::openai::chat_completions::{ use dynamo_llm::protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, aggregator::ChatCompletionAggregator, NvCreateChatCompletionRequest, aggregator::ChatCompletionAggregator,
}; };
use dynamo_llm::protocols::openai::completions::NvCreateCompletionRequest;
use dynamo_runtime::engine::{AsyncEngineContext, AsyncEngineStream}; use dynamo_runtime::engine::{AsyncEngineContext, AsyncEngineStream};
use dynamo_runtime::protocols::annotated::Annotated; use dynamo_runtime::protocols::annotated::Annotated;
use futures::StreamExt; use futures::StreamExt;
...@@ -82,8 +87,17 @@ impl AsyncEngineContext for MockContext { ...@@ -82,8 +87,17 @@ impl AsyncEngineContext for MockContext {
fn create_mock_backend_stream( fn create_mock_backend_stream(
ctx: Arc<dyn AsyncEngineContext>, ctx: Arc<dyn AsyncEngineContext>,
) -> Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>> { ) -> Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>> {
let outputs = vec![ let outputs = build_backend_outputs_with_cached_tokens(None);
// First chunk with "Hello"
let stream = stream::iter(outputs.into_iter().map(Annotated::from_data));
use dynamo_runtime::engine::ResponseStream;
ResponseStream::new(Box::pin(stream), ctx)
}
/// Build three backend outputs: "Hello", " world", "!" with optional cached_tokens on the final chunk
fn build_backend_outputs_with_cached_tokens(cached_tokens: Option<u32>) -> Vec<BackendOutput> {
vec![
BackendOutput { BackendOutput {
token_ids: vec![15339], token_ids: vec![15339],
tokens: vec![Some("Hello".to_string())], tokens: vec![Some("Hello".to_string())],
...@@ -93,8 +107,8 @@ fn create_mock_backend_stream( ...@@ -93,8 +107,8 @@ fn create_mock_backend_stream(
top_logprobs: None, top_logprobs: None,
finish_reason: None, finish_reason: None,
index: Some(0), index: Some(0),
completion_usage: None,
}, },
// Second chunk with " world"
BackendOutput { BackendOutput {
token_ids: vec![1917], token_ids: vec![1917],
tokens: vec![Some(" world".to_string())], tokens: vec![Some(" world".to_string())],
...@@ -104,8 +118,8 @@ fn create_mock_backend_stream( ...@@ -104,8 +118,8 @@ fn create_mock_backend_stream(
top_logprobs: None, top_logprobs: None,
finish_reason: None, finish_reason: None,
index: Some(0), index: Some(0),
completion_usage: None,
}, },
// Third chunk with "!" and finish_reason
BackendOutput { BackendOutput {
token_ids: vec![0], token_ids: vec![0],
tokens: vec![Some("!".to_string())], tokens: vec![Some("!".to_string())],
...@@ -115,11 +129,27 @@ fn create_mock_backend_stream( ...@@ -115,11 +129,27 @@ fn create_mock_backend_stream(
top_logprobs: None, top_logprobs: None,
finish_reason: Some(FinishReason::Stop), finish_reason: Some(FinishReason::Stop),
index: Some(0), index: Some(0),
completion_usage: cached_tokens.map(|ct| AoaiCompletionUsage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
prompt_tokens_details: Some(PromptTokensDetails {
audio_tokens: None,
cached_tokens: Some(ct),
}),
completion_tokens_details: None,
}),
}, },
]; ]
}
/// Create a backend stream from standard outputs with optional cached_tokens in the final chunk
fn create_backend_stream_with_cached_tokens(
ctx: Arc<dyn AsyncEngineContext>,
cached_tokens: Option<u32>,
) -> Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>> {
let outputs = build_backend_outputs_with_cached_tokens(cached_tokens);
let stream = stream::iter(outputs.into_iter().map(Annotated::from_data)); let stream = stream::iter(outputs.into_iter().map(Annotated::from_data));
use dynamo_runtime::engine::ResponseStream; use dynamo_runtime::engine::ResponseStream;
ResponseStream::new(Box::pin(stream), ctx) ResponseStream::new(Box::pin(stream), ctx)
} }
...@@ -308,6 +338,31 @@ async fn test_streaming_with_usage_false() { ...@@ -308,6 +338,31 @@ async fn test_streaming_with_usage_false() {
} }
} }
/// Helper to create a completion request with optional stream_options
fn create_cmpl_request(include_usage: Option<bool>, stream: bool) -> NvCreateCompletionRequest {
let inner = {
let mut builder = CreateCompletionRequestArgs::default();
builder
.model("test-model")
.prompt(Prompt::String("Hello".to_string()))
.stream(stream);
if let Some(include) = include_usage {
builder.stream_options(dynamo_async_openai::types::ChatCompletionStreamOptions {
include_usage: include,
});
}
builder.build().unwrap()
};
NvCreateCompletionRequest {
inner,
common: Default::default(),
nvext: None,
metadata: None,
unsupported_fields: Default::default(),
}
}
/// Helper to create a non-streaming chat completion request /// Helper to create a non-streaming chat completion request
fn create_nonstreaming_chat_request() -> NvCreateChatCompletionRequest { fn create_nonstreaming_chat_request() -> NvCreateChatCompletionRequest {
let messages = vec![ChatCompletionRequestMessage::User( let messages = vec![ChatCompletionRequestMessage::User(
...@@ -404,3 +459,195 @@ async fn test_nonstreaming_has_usage_field() { ...@@ -404,3 +459,195 @@ async fn test_nonstreaming_has_usage_field() {
"Total tokens should equal prompt_tokens + completion_tokens" "Total tokens should equal prompt_tokens + completion_tokens"
); );
} }
#[tokio::test]
async fn test_cmpl_streaming_with_usage_true_no_backend_usage() {
// Completions: stream=true, include_usage=true, but backend does not send completion_usage
let request = create_cmpl_request(Some(true), true);
let request_id = "cmpl-usage-none-1".to_string();
let response_generator = Box::new(request.response_generator(request_id));
// Mock backend stream (no completion_usage in any chunk)
let ctx = Arc::new(MockContext::new());
let backend_stream = create_mock_backend_stream(ctx.clone());
// Transform
let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
backend_stream,
response_generator,
ctx.clone(),
);
let chunks: Vec<_> = transformed_stream.collect().await;
// Expect 3 content chunks + 1 usage-only chunk
assert_eq!(chunks.len(), 4, "Should have 3 content + 1 usage chunk");
// First 3 chunks: usage must be None
for (i, chunk) in chunks.iter().take(3).enumerate() {
if let Some(resp) = &chunk.data {
assert!(
resp.inner.usage.is_none(),
"Content chunk {} should have usage: None",
i
);
assert!(
!resp.inner.choices.is_empty(),
"Content chunk {} should have choices",
i
);
}
}
// Final usage chunk: usage present with counts; prompt_tokens_details None (no backend usage)
if let Some(final_resp) = &chunks[3].data {
assert!(
final_resp.inner.choices.is_empty(),
"Usage-only chunk must have empty choices"
);
let usage = final_resp
.inner
.usage
.as_ref()
.expect("Usage must be present");
assert_eq!(
usage.completion_tokens, 3,
"Aggregated completion tokens should be 3"
);
assert!(
usage.prompt_tokens_details.is_none(),
"prompt_tokens_details should be None when backend does not send usage"
);
} else {
panic!("Final chunk should be present");
}
}
#[tokio::test]
async fn test_cmpl_streaming_with_cached_tokens_propagation() {
// Completions: include_usage=true, backend provides cached_tokens -> must propagate
let request = create_cmpl_request(Some(true), true);
let request_id = "cmpl-usage-cached-1".to_string();
let mut response_generator = Box::new(request.response_generator(request_id));
// Build a backend stream where the final chunk carries completion_usage with cached_tokens
let ctx = Arc::new(MockContext::new());
let backend_stream = create_backend_stream_with_cached_tokens(ctx.clone(), Some(7));
// Align ISL so total usage gets computed correctly
response_generator.update_isl(0);
let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
backend_stream,
response_generator,
ctx.clone(),
);
let chunks: Vec<_> = transformed_stream.collect().await;
// Expect 4 chunks total
assert_eq!(chunks.len(), 4, "Should have 3 content + 1 usage chunk");
// Final usage chunk should include cached_tokens propagated
if let Some(final_resp) = &chunks[3].data {
let usage = final_resp
.inner
.usage
.as_ref()
.expect("Usage must be present on final chunk");
let cached = usage
.prompt_tokens_details
.as_ref()
.and_then(|d| d.cached_tokens);
assert_eq!(
cached,
Some(7),
"cached_tokens must propagate to final usage chunk"
);
} else {
panic!("Final chunk should be present");
}
}
#[tokio::test]
async fn test_chat_streaming_with_cached_tokens_propagation() {
// Chat Completions: include_usage=true, backend provides cached_tokens -> must propagate
let request = create_chat_request(Some(true));
let request_id = "chat-usage-cached-1".to_string();
let mut response_generator = Box::new(request.response_generator(request_id));
let ctx = Arc::new(MockContext::new());
let backend_stream = create_backend_stream_with_cached_tokens(ctx.clone(), Some(5));
// Align ISL if needed
response_generator.update_isl(0);
let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
backend_stream,
response_generator,
ctx.clone(),
);
let chunks: Vec<_> = transformed_stream.collect().await;
assert_eq!(chunks.len(), 4, "Should have 3 content + 1 usage chunk");
if let Some(final_resp) = &chunks[3].data {
let usage = final_resp.usage.as_ref().expect("Usage must be present");
let cached = usage
.prompt_tokens_details
.as_ref()
.and_then(|d| d.cached_tokens);
assert_eq!(
cached,
Some(5),
"cached_tokens must propagate for chat completions"
);
} else {
panic!("Final chunk should be present");
}
}
#[tokio::test]
async fn test_cmpl_nonstreaming_has_usage_and_cached_tokens() {
// Non-streaming completions must include usage in final aggregated response and propagate cached_tokens
let mut request = create_cmpl_request(None, false);
// Simulate preprocessor behavior for non-streaming
let original_stream_flag = request.inner.stream.unwrap_or(false);
request.enable_usage_for_nonstreaming(original_stream_flag);
let request_id = "cmpl-nonstream-usage".to_string();
let response_generator = Box::new(request.response_generator(request_id));
// Mock backend stream with 3 chunks, last carries completion_usage with cached_tokens
let ctx = Arc::new(MockContext::new());
let backend_stream = create_backend_stream_with_cached_tokens(ctx.clone(), Some(9));
// Transform to OpenAI completion stream
let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
backend_stream,
response_generator,
ctx.clone(),
);
// Aggregate into a single non-streaming response
let parsing = ParsingOptions::default();
let result =
dynamo_llm::protocols::openai::completions::NvCreateCompletionResponse::from_annotated_stream(
transformed_stream,
parsing,
)
.await;
assert!(result.is_ok(), "Aggregation should succeed");
let resp = result.unwrap();
let usage = resp
.inner
.usage
.expect("usage must be present for non-streaming");
assert_eq!(
usage.completion_tokens, 3,
"completion_tokens must aggregate"
);
let cached = usage.prompt_tokens_details.and_then(|d| d.cached_tokens);
assert_eq!(
cached,
Some(9),
"cached_tokens must propagate to non-streaming response"
);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment