Commit be9d6b2b (unverified), authored by Vladislav Nosivskoy, committed by GitHub
Browse files

feat: support prompt_tokens_details in usage (#4239)


Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
parent 0f4d7634
......@@ -120,6 +120,7 @@ class StandaloneRouterHandler:
"index": worker_output.get("index"),
"disaggregated_params": worker_output.get("disaggregated_params"),
"extra_args": worker_output.get("extra_args"),
"completion_usage": worker_output.get("completion_usage"),
}
yield llm_engine_output
......
......@@ -229,6 +229,19 @@ class DecodeWorkerHandler(BaseWorkerHandler):
next_total_toks = len(output_ids)
out["token_ids"] = output_ids[num_output_tokens_so_far:]
num_output_tokens_so_far = next_total_toks
if finish_reason:
input_tokens = res["meta_info"]["prompt_tokens"]
completion_tokens = res["meta_info"]["completion_tokens"]
cached_tokens = res["meta_info"]["cached_tokens"]
prefill_prompt_tokens_details = None
if cached_tokens is not None and cached_tokens > 0:
prefill_prompt_tokens_details = {"cached_tokens": cached_tokens}
out["completion_usage"] = {
"prompt_tokens": input_tokens,
"completion_tokens": completion_tokens,
"total_tokens": input_tokens + completion_tokens,
"prompt_tokens_details": prefill_prompt_tokens_details,
}
if not context.is_stopped():
yield out
......
......@@ -242,6 +242,10 @@ async def init(runtime: DistributedRuntime, config: Config):
default_sampling_params = SamplingParams()
default_sampling_params._setup(tokenizer)
default_sampling_params.stop = None
# Enable perf metrics so prompt_tokens_details can be returned
if hasattr(default_sampling_params, "return_perf_metrics"):
default_sampling_params.return_perf_metrics = True
model_input = ModelInput.Tokens
# Set model type based on disaggregation mode for unified frontend support
......@@ -356,6 +360,7 @@ async def init(runtime: DistributedRuntime, config: Config):
connector=connector,
runtime=runtime, # Pass runtime for graceful shutdown
metrics_collector=metrics_collector,
kv_block_size=config.kv_block_size,
)
# Register the model with runtime config
......
......@@ -72,6 +72,7 @@ class RequestHandlerConfig:
DistributedRuntime
] = None # DistributedRuntime reference for graceful shutdown
metrics_collector: Optional[Any] = None # TensorRT-LLM MetricsCollector
kv_block_size: int = 32
class HandlerBase:
......@@ -92,6 +93,7 @@ class HandlerBase:
self.connector = config.connector
# Store runtime reference for graceful shutdown
self.runtime = config.runtime
self.kv_block_size: int = config.kv_block_size
def check_error(self, result: dict):
"""
......@@ -208,11 +210,13 @@ class HandlerBase:
request["stop_conditions"]["max_tokens"] = 1
disaggregated_params = LlmDisaggregatedParams(request_type="context_only")
if "disaggregated_params" in request:
if "prefill_result" in request:
if self.disaggregation_mode == DisaggregationMode.PREFILL:
raise ValueError("Cannot provide disaggregated_params in prefill mode")
disaggregated_params = DisaggregatedParamsCodec.decode(
DisaggregatedParams(**request["disaggregated_params"])
DisaggregatedParams(
**request["prefill_result"].get("disaggregated_params")
)
)
disaggregated_params.request_type = "generation_only"
......@@ -258,6 +262,11 @@ class HandlerBase:
adapters = create_trtllm_adapters(processors)
sampling_params.logits_processor = adapters
prefill_result = request.get("prefill_result")
prefill_prompt_tokens_details = (
prefill_result.get("prompt_tokens_details") if prefill_result else None
)
try:
# NEW: Updated engine call to include multimodal data
generation_result = self.engine.llm.generate_async(
......@@ -298,6 +307,34 @@ class HandlerBase:
DisaggregatedParamsCodec.encode(output.disaggregated_params)
)
if out.get("finish_reason"):
num_input_tokens = len(request.get("token_ids", []))
prompt_tokens_details = None
if prefill_prompt_tokens_details:
prompt_tokens_details = prefill_prompt_tokens_details
else:
if output.request_perf_metrics is not None:
kv_cache_metrics = (
output.request_perf_metrics.kv_cache_metrics
)
cached_tokens = min(
num_input_tokens,
kv_cache_metrics.num_reused_blocks
* self.kv_block_size,
)
if cached_tokens > 0:
prompt_tokens_details = {
"cached_tokens": int(cached_tokens),
}
out["completion_usage"] = {
"prompt_tokens": int(num_input_tokens),
"completion_tokens": int(next_total_toks),
"total_tokens": int(num_input_tokens + next_total_toks),
"prompt_tokens_details": prompt_tokens_details,
}
if res.finished and not out.get("finish_reason"):
out["finish_reason"] = "unknown"
logging.warning(
......
......@@ -10,6 +10,7 @@ from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Final
from vllm.inputs import TokensPrompt
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.exceptions import EngineDeadError
......@@ -174,6 +175,28 @@ class BaseWorkerHandler(ABC):
return vllm_mm_data if vllm_mm_data else None
@staticmethod
def _build_completion_usage(request_output: RequestOutput) -> Dict[str, Any]:
return {
"prompt_tokens": (
len(request_output.prompt_token_ids)
if request_output.prompt_token_ids
else None
),
"completion_tokens": len(request_output.outputs[0].token_ids),
"total_tokens": (
len(request_output.prompt_token_ids)
+ len(request_output.outputs[0].token_ids)
if request_output.prompt_token_ids
else None
),
"prompt_tokens_details": (
{"cached_tokens": request_output.num_cached_tokens}
if request_output.num_cached_tokens
else None
),
}
async def generate_tokens(
self, prompt, sampling_params, request_id, data_parallel_rank=None
):
......@@ -199,6 +222,11 @@ class BaseWorkerHandler(ABC):
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
out[
"completion_usage"
] = BaseWorkerHandler._build_completion_usage(
request_output=res
)
if output.stop_reason:
out["stop_reason"] = output.stop_reason
yield out
......@@ -241,18 +269,24 @@ class DecodeWorkerHandler(BaseWorkerHandler):
# Build sampling params from request
sampling_params = build_sampling_params(request, self.default_sampling_params)
# Extract disaggregated_params from request (set by prefill router in Rust frontend)
disaggregated_params = request.get("disaggregated_params")
if disaggregated_params:
# Prefill was performed - use the disaggregated params
if sampling_params.extra_args is None:
sampling_params.extra_args = {}
sampling_params.extra_args["kv_transfer_params"] = disaggregated_params.get(
prefill_result = request.get("prefill_result")
if prefill_result and isinstance(prefill_result, dict):
kv_params = prefill_result.get("disaggregated_params", {}).get(
"kv_transfer_params"
)
else:
kv_params = None
if kv_params is not None:
if sampling_params.extra_args is None:
sampling_params.extra_args = {}
sampling_params.extra_args["kv_transfer_params"] = kv_params
logger.debug(
f"Using disaggregated params from prefill for request {request_id}"
)
prefill_prompt_tokens_details = (
prefill_result.get("prompt_tokens_details") if prefill_result else None
)
dp_rank = request.get("dp_rank", None)
......@@ -261,6 +295,10 @@ class DecodeWorkerHandler(BaseWorkerHandler):
async for tok in self.generate_tokens(
prompt, sampling_params, request_id, data_parallel_rank=dp_rank
):
if prefill_result is not None and "completion_usage" in tok:
tok["completion_usage"][
"prompt_tokens_details"
] = prefill_prompt_tokens_details
yield tok
except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}")
......@@ -325,6 +363,9 @@ class PrefillWorkerHandler(BaseWorkerHandler):
if res.kv_transfer_params
else None
),
"completion_usage": BaseWorkerHandler._build_completion_usage(
request_output=res
),
}
yield output
......
......@@ -242,6 +242,7 @@ impl
finish_reason: data.finish_reason,
//mdcsum: mdcsum.clone(),
index: data.index,
completion_usage: data.completion_usage,
})
})
});
......
......@@ -21,6 +21,7 @@ use crate::{
discovery::ModelManager,
kv_router::{KvPushRouter, KvRouterConfig, RouterConfigOverride},
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
protocols::common::preprocessor::PrefillResult,
};
/// Errors that can occur during prefill routing
......@@ -175,11 +176,11 @@ impl PrefillRouter {
Ok(())
}
/// Call the prefill router and extract disaggregated_params
/// Call the prefill router and extract structured prefill result
async fn call_prefill(
&self,
request: SingleIn<PreprocessedRequest>,
) -> Result<serde_json::Value, PrefillError> {
) -> Result<PrefillResult, PrefillError> {
// Get the prefill router, error if not activated
let Some(prefill_router) = self.prefill_router.get() else {
return Err(PrefillError::NotActivated);
......@@ -203,7 +204,22 @@ impl PrefillRouter {
));
};
while prefill_response.next().await.is_some() {}
let mut prompt_tokens_details = first_output
.data
.as_ref()
.and_then(|o| o.completion_usage.as_ref())
.and_then(|u| u.prompt_tokens_details.clone());
while let Some(next) = prefill_response.next().await {
if let Some(o) = next.data.as_ref()
&& prompt_tokens_details.is_none()
{
prompt_tokens_details = o
.completion_usage
.as_ref()
.and_then(|u| u.prompt_tokens_details.clone());
}
}
if let Some(err) = first_output.err() {
return Err(PrefillError::PrefillError(format!(
......@@ -223,7 +239,10 @@ impl PrefillRouter {
));
};
Ok(disaggregated_params)
Ok(PrefillResult {
disaggregated_params,
prompt_tokens_details,
})
}
}
......@@ -267,12 +286,12 @@ impl
// Attempt prefill and handle results
match self.call_prefill(prefill_request).await {
Ok(disaggregated_params) => {
Ok(prefill_result) => {
tracing::debug!("Prefill succeeded, using disaggregated params for decode");
// Update request with disaggregated_params and router config
let mut decode_req = req;
decode_req.disaggregated_params = Some(disaggregated_params);
// Update request with prefill result
decode_req.prefill_result = Some(prefill_result.clone());
// Restore original max_tokens for decode
decode_req.stop_conditions.max_tokens = original_max_tokens;
......
......@@ -219,6 +219,7 @@ mod tests {
index: None,
disaggregated_params: None,
extra_args: None,
completion_usage: None,
})
}
......
......@@ -308,6 +308,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
None
},
extra_args: None,
completion_usage: None,
};
if signal.completed && token_count < max_output_tokens {
......
......@@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize};
pub use super::FinishReason;
pub use super::preprocessor::PreprocessedRequest;
use crate::protocols::TokenIdType;
use dynamo_async_openai::types::CompletionUsage;
use dynamo_runtime::protocols::maybe_error::MaybeError;
pub type TokenType = Option<String>;
......@@ -48,6 +49,10 @@ pub struct BackendOutput {
// Index field for batch requests to match OpenAI format
pub index: Option<u32>,
// Token usage information
#[serde(default, skip_serializing_if = "Option::is_none")]
pub completion_usage: Option<CompletionUsage>,
}
/// The LLM engine and backend will manage their own state, specifically translating how a
......@@ -92,6 +97,10 @@ pub struct LLMEngineOutput {
/// Additional arguments for extensibility
#[serde(default, skip_serializing_if = "Option::is_none")]
pub extra_args: Option<serde_json::Value>,
// Token usage information
#[serde(default, skip_serializing_if = "Option::is_none")]
pub completion_usage: Option<CompletionUsage>,
}
impl LLMEngineOutput {
......@@ -107,6 +116,7 @@ impl LLMEngineOutput {
index: None,
disaggregated_params: None,
extra_args: None,
completion_usage: None,
}
}
......@@ -122,6 +132,7 @@ impl LLMEngineOutput {
index: None,
disaggregated_params: None,
extra_args: None,
completion_usage: None,
}
}
......@@ -137,6 +148,7 @@ impl LLMEngineOutput {
index: None,
disaggregated_params: None,
extra_args: None,
completion_usage: None,
}
}
......@@ -152,6 +164,7 @@ impl LLMEngineOutput {
index: None,
disaggregated_params: None,
extra_args: None,
completion_usage: None,
}
}
}
......
......@@ -8,6 +8,15 @@ use super::{OutputOptions, SamplingOptions, StopConditions};
use crate::kv_router::RouterConfigOverride;
use crate::protocols::TokenIdType;
/// Structured outcome of a prefill call.
///
/// `PrefillRouter::call_prefill` builds this value, and it is stored on
/// `PreprocessedRequest::prefill_result` before the request is sent on to decode.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PrefillResult {
/// Disaggregated execution parameters
pub disaggregated_params: serde_json::Value,
/// Prompt token details produced during prefill
// Omitted from the serialized form when `None`; a missing field deserializes as `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prompt_tokens_details: Option<dynamo_async_openai::types::PromptTokensDetails>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum MultimodalData {
Url(url::Url),
......@@ -69,10 +78,10 @@ pub struct PreprocessedRequest {
#[builder(default)]
pub router_config_override: Option<RouterConfigOverride>,
/// Disaggregated execution parameters (for prefill/decode separation)
/// Structured prefill result
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub disaggregated_params: Option<serde_json::Value>,
pub prefill_result: Option<PrefillResult>,
/// Data parallel rank for the request (used with data parallelism)
#[builder(default)]
......
......@@ -316,6 +316,16 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
.expect("token_ids length exceeds u32::MAX");
self.usage.completion_tokens += token_length;
// If backend provides completion_usage with prompt token details,
// propagate the entire details struct to usage tracking
if let Some(prompt_details) = delta
.completion_usage
.as_ref()
.and_then(|usage| usage.prompt_tokens_details.as_ref())
{
self.usage.prompt_tokens_details = Some(prompt_details.clone());
}
}
let logprobs = self.create_logprobs(
......
......@@ -238,6 +238,16 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
.expect("token_ids length exceeds u32::MAX");
self.usage.completion_tokens += token_length;
// If backend provides completion_usage with prompt token details,
// propagate the entire details struct to usage tracking
if let Some(prompt_details) = delta
.completion_usage
.as_ref()
.and_then(|usage| usage.prompt_tokens_details.as_ref())
{
self.usage.prompt_tokens_details = Some(prompt_details.clone());
}
}
let logprobs = self.create_logprobs(
......
......@@ -7,12 +7,17 @@ use dynamo_async_openai::types::{
ChatCompletionRequestUserMessageContent, ChatCompletionStreamOptions,
CreateChatCompletionRequest,
};
use dynamo_async_openai::types::{
CompletionUsage as AoaiCompletionUsage, CreateCompletionRequestArgs, Prompt,
PromptTokensDetails,
};
use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_llm::protocols::common::llm_backend::{BackendOutput, FinishReason};
use dynamo_llm::protocols::openai::ParsingOptions;
use dynamo_llm::protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, aggregator::ChatCompletionAggregator,
};
use dynamo_llm::protocols::openai::completions::NvCreateCompletionRequest;
use dynamo_runtime::engine::{AsyncEngineContext, AsyncEngineStream};
use dynamo_runtime::protocols::annotated::Annotated;
use futures::StreamExt;
......@@ -82,8 +87,17 @@ impl AsyncEngineContext for MockContext {
fn create_mock_backend_stream(
ctx: Arc<dyn AsyncEngineContext>,
) -> Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>> {
let outputs = vec![
// First chunk with "Hello"
let outputs = build_backend_outputs_with_cached_tokens(None);
let stream = stream::iter(outputs.into_iter().map(Annotated::from_data));
use dynamo_runtime::engine::ResponseStream;
ResponseStream::new(Box::pin(stream), ctx)
}
/// Build three backend outputs: "Hello", " world", "!" with optional cached_tokens on the final chunk
fn build_backend_outputs_with_cached_tokens(cached_tokens: Option<u32>) -> Vec<BackendOutput> {
vec![
BackendOutput {
token_ids: vec![15339],
tokens: vec![Some("Hello".to_string())],
......@@ -93,8 +107,8 @@ fn create_mock_backend_stream(
top_logprobs: None,
finish_reason: None,
index: Some(0),
completion_usage: None,
},
// Second chunk with " world"
BackendOutput {
token_ids: vec![1917],
tokens: vec![Some(" world".to_string())],
......@@ -104,8 +118,8 @@ fn create_mock_backend_stream(
top_logprobs: None,
finish_reason: None,
index: Some(0),
completion_usage: None,
},
// Third chunk with "!" and finish_reason
BackendOutput {
token_ids: vec![0],
tokens: vec![Some("!".to_string())],
......@@ -115,11 +129,27 @@ fn create_mock_backend_stream(
top_logprobs: None,
finish_reason: Some(FinishReason::Stop),
index: Some(0),
completion_usage: cached_tokens.map(|ct| AoaiCompletionUsage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
prompt_tokens_details: Some(PromptTokensDetails {
audio_tokens: None,
cached_tokens: Some(ct),
}),
completion_tokens_details: None,
}),
},
];
]
}
/// Wrap the standard backend outputs in a response stream, optionally attaching
/// `cached_tokens` usage details to the final chunk.
fn create_backend_stream_with_cached_tokens(
    ctx: Arc<dyn AsyncEngineContext>,
    cached_tokens: Option<u32>,
) -> Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>> {
    use dynamo_runtime::engine::ResponseStream;

    // Each backend output becomes one annotated data item of the stream.
    let annotated_chunks = build_backend_outputs_with_cached_tokens(cached_tokens)
        .into_iter()
        .map(Annotated::from_data);
    ResponseStream::new(Box::pin(stream::iter(annotated_chunks)), ctx)
}
......@@ -308,6 +338,31 @@ async fn test_streaming_with_usage_false() {
}
}
/// Build a completion request for the tests.
///
/// `stream` sets the request's stream flag; when `include_usage` is `Some`, stream options
/// carrying that `include_usage` value are attached, otherwise none are set.
fn create_cmpl_request(include_usage: Option<bool>, stream: bool) -> NvCreateCompletionRequest {
    let mut args = CreateCompletionRequestArgs::default();
    args.model("test-model")
        .prompt(Prompt::String("Hello".to_string()))
        .stream(stream);

    // Stream options are only set when the caller asked for an explicit value.
    if let Some(include_usage) = include_usage {
        args.stream_options(dynamo_async_openai::types::ChatCompletionStreamOptions {
            include_usage,
        });
    }

    NvCreateCompletionRequest {
        inner: args.build().unwrap(),
        common: Default::default(),
        nvext: None,
        metadata: None,
        unsupported_fields: Default::default(),
    }
}
/// Helper to create a non-streaming chat completion request
fn create_nonstreaming_chat_request() -> NvCreateChatCompletionRequest {
let messages = vec![ChatCompletionRequestMessage::User(
......@@ -404,3 +459,195 @@ async fn test_nonstreaming_has_usage_field() {
"Total tokens should equal prompt_tokens + completion_tokens"
);
}
/// Completions with `stream=true` and `include_usage=true` where the backend never sends
/// `completion_usage`: expects three content chunks without usage, then one usage-only chunk
/// whose `prompt_tokens_details` is `None`.
#[tokio::test]
async fn test_cmpl_streaming_with_usage_true_no_backend_usage() {
    // Completions: stream=true, include_usage=true, but backend does not send completion_usage
    let request = create_cmpl_request(Some(true), true);
    let request_id = "cmpl-usage-none-1".to_string();
    let response_generator = Box::new(request.response_generator(request_id));

    // Mock backend stream (no completion_usage in any chunk)
    let ctx = Arc::new(MockContext::new());
    let backend_stream = create_mock_backend_stream(ctx.clone());

    // Transform
    let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
        backend_stream,
        response_generator,
        ctx.clone(),
    );
    let chunks: Vec<_> = transformed_stream.collect().await;

    // Expect 3 content chunks + 1 usage-only chunk
    assert_eq!(chunks.len(), 4, "Should have 3 content + 1 usage chunk");

    // First 3 chunks: usage must be None. A chunk without data is a failure, so the
    // assertions below cannot be skipped silently.
    for (i, chunk) in chunks.iter().take(3).enumerate() {
        let resp = chunk
            .data
            .as_ref()
            .unwrap_or_else(|| panic!("Content chunk {} should be present", i));
        assert!(
            resp.inner.usage.is_none(),
            "Content chunk {} should have usage: None",
            i
        );
        assert!(
            !resp.inner.choices.is_empty(),
            "Content chunk {} should have choices",
            i
        );
    }

    // Final usage chunk: usage present with counts; prompt_tokens_details None (no backend usage)
    let final_resp = chunks[3]
        .data
        .as_ref()
        .expect("Final chunk should be present");
    assert!(
        final_resp.inner.choices.is_empty(),
        "Usage-only chunk must have empty choices"
    );
    let usage = final_resp
        .inner
        .usage
        .as_ref()
        .expect("Usage must be present");
    assert_eq!(
        usage.completion_tokens, 3,
        "Aggregated completion tokens should be 3"
    );
    assert!(
        usage.prompt_tokens_details.is_none(),
        "prompt_tokens_details should be None when backend does not send usage"
    );
}
/// Streaming completions with `include_usage=true`: `cached_tokens` reported by the backend
/// on its final chunk must show up in the trailing usage chunk.
#[tokio::test]
async fn test_cmpl_streaming_with_cached_tokens_propagation() {
    let request = create_cmpl_request(Some(true), true);
    let mut response_generator =
        Box::new(request.response_generator("cmpl-usage-cached-1".to_string()));
    // Align ISL so total usage gets computed correctly
    response_generator.update_isl(0);

    // The final backend chunk carries completion_usage with cached_tokens = 7.
    let ctx = Arc::new(MockContext::new());
    let backend_stream = create_backend_stream_with_cached_tokens(ctx.clone(), Some(7));

    let chunks: Vec<_> = OpenAIPreprocessor::transform_postprocessor_stream(
        backend_stream,
        response_generator,
        ctx.clone(),
    )
    .collect()
    .await;

    // Expect 4 chunks total
    assert_eq!(chunks.len(), 4, "Should have 3 content + 1 usage chunk");

    // Final usage chunk should include cached_tokens propagated
    let Some(final_resp) = &chunks[3].data else {
        panic!("Final chunk should be present");
    };
    let usage = final_resp
        .inner
        .usage
        .as_ref()
        .expect("Usage must be present on final chunk");
    let cached_tokens = usage
        .prompt_tokens_details
        .as_ref()
        .and_then(|details| details.cached_tokens);
    assert_eq!(
        cached_tokens,
        Some(7),
        "cached_tokens must propagate to final usage chunk"
    );
}
/// Streaming chat completions with `include_usage=true`: `cached_tokens` reported by the
/// backend must show up in the trailing usage chunk.
#[tokio::test]
async fn test_chat_streaming_with_cached_tokens_propagation() {
    let request = create_chat_request(Some(true));
    let mut response_generator =
        Box::new(request.response_generator("chat-usage-cached-1".to_string()));
    // Align ISL if needed
    response_generator.update_isl(0);

    // The final backend chunk carries completion_usage with cached_tokens = 5.
    let ctx = Arc::new(MockContext::new());
    let backend_stream = create_backend_stream_with_cached_tokens(ctx.clone(), Some(5));

    let chunks: Vec<_> = OpenAIPreprocessor::transform_postprocessor_stream(
        backend_stream,
        response_generator,
        ctx.clone(),
    )
    .collect()
    .await;
    assert_eq!(chunks.len(), 4, "Should have 3 content + 1 usage chunk");

    let Some(final_resp) = &chunks[3].data else {
        panic!("Final chunk should be present");
    };
    let usage = final_resp.usage.as_ref().expect("Usage must be present");
    let cached_tokens = usage
        .prompt_tokens_details
        .as_ref()
        .and_then(|details| details.cached_tokens);
    assert_eq!(
        cached_tokens,
        Some(5),
        "cached_tokens must propagate for chat completions"
    );
}
/// Non-streaming completions: the aggregated response must contain usage, with the
/// completion token count summed and the backend's `cached_tokens` propagated.
#[tokio::test]
async fn test_cmpl_nonstreaming_has_usage_and_cached_tokens() {
    let mut request = create_cmpl_request(None, false);

    // Simulate preprocessor behavior for non-streaming
    let was_streaming = request.inner.stream.unwrap_or(false);
    request.enable_usage_for_nonstreaming(was_streaming);

    let response_generator =
        Box::new(request.response_generator("cmpl-nonstream-usage".to_string()));

    // Three backend chunks; the last one carries completion_usage with cached_tokens = 9.
    let ctx = Arc::new(MockContext::new());
    let backend_stream = create_backend_stream_with_cached_tokens(ctx.clone(), Some(9));

    // Transform to OpenAI completion stream
    let openai_stream = OpenAIPreprocessor::transform_postprocessor_stream(
        backend_stream,
        response_generator,
        ctx.clone(),
    );

    // Aggregate into a single non-streaming response
    let aggregated =
        dynamo_llm::protocols::openai::completions::NvCreateCompletionResponse::from_annotated_stream(
            openai_stream,
            ParsingOptions::default(),
        )
        .await;
    assert!(aggregated.is_ok(), "Aggregation should succeed");

    let usage = aggregated
        .unwrap()
        .inner
        .usage
        .expect("usage must be present for non-streaming");
    assert_eq!(
        usage.completion_tokens, 3,
        "completion_tokens must aggregate"
    );
    let cached_tokens = usage
        .prompt_tokens_details
        .and_then(|details| details.cached_tokens);
    assert_eq!(
        cached_tokens,
        Some(9),
        "cached_tokens must propagate to non-streaming response"
    );
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment