Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
be9d6b2b
Unverified
Commit
be9d6b2b
authored
Nov 19, 2025
by
Vladislav Nosivskoy
Committed by
GitHub
Nov 18, 2025
Browse files
feat: support prompt_tokens_details in usage (#4239)
Signed-off-by:
Vladislav Nosivskoy
<
vladnosiv@gmail.com
>
parent
0f4d7634
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
432 additions
and
24 deletions
+432
-24
components/src/dynamo/router/__main__.py
components/src/dynamo/router/__main__.py
+1
-0
components/src/dynamo/sglang/request_handlers/llm/decode_handler.py
.../src/dynamo/sglang/request_handlers/llm/decode_handler.py
+13
-0
components/src/dynamo/trtllm/main.py
components/src/dynamo/trtllm/main.py
+5
-0
components/src/dynamo/trtllm/request_handlers/handler_base.py
...onents/src/dynamo/trtllm/request_handlers/handler_base.py
+39
-2
components/src/dynamo/vllm/handlers.py
components/src/dynamo/vllm/handlers.py
+48
-7
lib/llm/src/backend.rs
lib/llm/src/backend.rs
+1
-0
lib/llm/src/kv_router/prefill_router.rs
lib/llm/src/kv_router/prefill_router.rs
+26
-7
lib/llm/src/migration.rs
lib/llm/src/migration.rs
+1
-0
lib/llm/src/mocker/engine.rs
lib/llm/src/mocker/engine.rs
+1
-0
lib/llm/src/protocols/common/llm_backend.rs
lib/llm/src/protocols/common/llm_backend.rs
+13
-0
lib/llm/src/protocols/common/preprocessor.rs
lib/llm/src/protocols/common/preprocessor.rs
+11
-2
lib/llm/src/protocols/openai/chat_completions/delta.rs
lib/llm/src/protocols/openai/chat_completions/delta.rs
+10
-0
lib/llm/src/protocols/openai/completions/delta.rs
lib/llm/src/protocols/openai/completions/delta.rs
+10
-0
lib/llm/tests/test_streaming_usage.rs
lib/llm/tests/test_streaming_usage.rs
+253
-6
No files found.
components/src/dynamo/router/__main__.py
View file @
be9d6b2b
...
...
@@ -120,6 +120,7 @@ class StandaloneRouterHandler:
"index"
:
worker_output
.
get
(
"index"
),
"disaggregated_params"
:
worker_output
.
get
(
"disaggregated_params"
),
"extra_args"
:
worker_output
.
get
(
"extra_args"
),
"completion_usage"
:
worker_output
.
get
(
"completion_usage"
),
}
yield
llm_engine_output
...
...
components/src/dynamo/sglang/request_handlers/llm/decode_handler.py
View file @
be9d6b2b
...
...
@@ -229,6 +229,19 @@ class DecodeWorkerHandler(BaseWorkerHandler):
next_total_toks
=
len
(
output_ids
)
out
[
"token_ids"
]
=
output_ids
[
num_output_tokens_so_far
:]
num_output_tokens_so_far
=
next_total_toks
if
finish_reason
:
input_tokens
=
res
[
"meta_info"
][
"prompt_tokens"
]
completion_tokens
=
res
[
"meta_info"
][
"completion_tokens"
]
cached_tokens
=
res
[
"meta_info"
][
"cached_tokens"
]
prefill_prompt_tokens_details
=
None
if
cached_tokens
is
not
None
and
cached_tokens
>
0
:
prefill_prompt_tokens_details
=
{
"cached_tokens"
:
cached_tokens
}
out
[
"completion_usage"
]
=
{
"prompt_tokens"
:
input_tokens
,
"completion_tokens"
:
completion_tokens
,
"total_tokens"
:
input_tokens
+
completion_tokens
,
"prompt_tokens_details"
:
prefill_prompt_tokens_details
,
}
if
not
context
.
is_stopped
():
yield
out
...
...
components/src/dynamo/trtllm/main.py
View file @
be9d6b2b
...
...
@@ -242,6 +242,10 @@ async def init(runtime: DistributedRuntime, config: Config):
default_sampling_params
=
SamplingParams
()
default_sampling_params
.
_setup
(
tokenizer
)
default_sampling_params
.
stop
=
None
# Enable perf metrics so prompt_tokens_details can be returned
if
hasattr
(
default_sampling_params
,
"return_perf_metrics"
):
default_sampling_params
.
return_perf_metrics
=
True
model_input
=
ModelInput
.
Tokens
# Set model type based on disaggregation mode for unified frontend support
...
...
@@ -356,6 +360,7 @@ async def init(runtime: DistributedRuntime, config: Config):
connector
=
connector
,
runtime
=
runtime
,
# Pass runtime for graceful shutdown
metrics_collector
=
metrics_collector
,
kv_block_size
=
config
.
kv_block_size
,
)
# Register the model with runtime config
...
...
components/src/dynamo/trtllm/request_handlers/handler_base.py
View file @
be9d6b2b
...
...
@@ -72,6 +72,7 @@ class RequestHandlerConfig:
DistributedRuntime
]
=
None
# DistributedRuntime reference for graceful shutdown
metrics_collector
:
Optional
[
Any
]
=
None
# TensorRT-LLM MetricsCollector
kv_block_size
:
int
=
32
class
HandlerBase
:
...
...
@@ -92,6 +93,7 @@ class HandlerBase:
self
.
connector
=
config
.
connector
# Store runtime reference for graceful shutdown
self
.
runtime
=
config
.
runtime
self
.
kv_block_size
:
int
=
config
.
kv_block_size
def
check_error
(
self
,
result
:
dict
):
"""
...
...
@@ -208,11 +210,13 @@ class HandlerBase:
request
[
"stop_conditions"
][
"max_tokens"
]
=
1
disaggregated_params
=
LlmDisaggregatedParams
(
request_type
=
"context_only"
)
if
"
disaggregated_params
"
in
request
:
if
"
prefill_result
"
in
request
:
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
raise
ValueError
(
"Cannot provide disaggregated_params in prefill mode"
)
disaggregated_params
=
DisaggregatedParamsCodec
.
decode
(
DisaggregatedParams
(
**
request
[
"disaggregated_params"
])
DisaggregatedParams
(
**
request
[
"prefill_result"
].
get
(
"disaggregated_params"
)
)
)
disaggregated_params
.
request_type
=
"generation_only"
...
...
@@ -258,6 +262,11 @@ class HandlerBase:
adapters
=
create_trtllm_adapters
(
processors
)
sampling_params
.
logits_processor
=
adapters
prefill_result
=
request
.
get
(
"prefill_result"
)
prefill_prompt_tokens_details
=
(
prefill_result
.
get
(
"prompt_tokens_details"
)
if
prefill_result
else
None
)
try
:
# NEW: Updated engine call to include multimodal data
generation_result
=
self
.
engine
.
llm
.
generate_async
(
...
...
@@ -298,6 +307,34 @@ class HandlerBase:
DisaggregatedParamsCodec
.
encode
(
output
.
disaggregated_params
)
)
if
out
.
get
(
"finish_reason"
):
num_input_tokens
=
len
(
request
.
get
(
"token_ids"
,
[]))
prompt_tokens_details
=
None
if
prefill_prompt_tokens_details
:
prompt_tokens_details
=
prefill_prompt_tokens_details
else
:
if
output
.
request_perf_metrics
is
not
None
:
kv_cache_metrics
=
(
output
.
request_perf_metrics
.
kv_cache_metrics
)
cached_tokens
=
min
(
num_input_tokens
,
kv_cache_metrics
.
num_reused_blocks
*
self
.
kv_block_size
,
)
if
cached_tokens
>
0
:
prompt_tokens_details
=
{
"cached_tokens"
:
int
(
cached_tokens
),
}
out
[
"completion_usage"
]
=
{
"prompt_tokens"
:
int
(
num_input_tokens
),
"completion_tokens"
:
int
(
next_total_toks
),
"total_tokens"
:
int
(
num_input_tokens
+
next_total_toks
),
"prompt_tokens_details"
:
prompt_tokens_details
,
}
if
res
.
finished
and
not
out
.
get
(
"finish_reason"
):
out
[
"finish_reason"
]
=
"unknown"
logging
.
warning
(
...
...
components/src/dynamo/vllm/handlers.py
View file @
be9d6b2b
...
...
@@ -10,6 +10,7 @@ from contextlib import asynccontextmanager
from
typing
import
Any
,
AsyncGenerator
,
Dict
,
Final
from
vllm.inputs
import
TokensPrompt
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.engine.exceptions
import
EngineDeadError
...
...
@@ -174,6 +175,28 @@ class BaseWorkerHandler(ABC):
return
vllm_mm_data
if
vllm_mm_data
else
None
@staticmethod
def _build_completion_usage(request_output: RequestOutput) -> Dict[str, Any]:
    """Build the ``completion_usage`` dictionary for a request output.

    Args:
        request_output: Engine output whose ``prompt_token_ids``,
            ``outputs[0].token_ids`` and ``num_cached_tokens`` are read.

    Returns:
        A dict with the keys ``prompt_tokens``, ``completion_tokens``,
        ``total_tokens`` and ``prompt_tokens_details``. ``prompt_tokens`` and
        ``total_tokens`` are ``None`` when ``prompt_token_ids`` is empty or
        ``None``. ``prompt_tokens_details`` is ``None`` unless
        ``num_cached_tokens`` is truthy, in which case it is
        ``{"cached_tokens": num_cached_tokens}``.
    """
    prompt_ids = request_output.prompt_token_ids
    # Only the first candidate output is counted.
    generated_count = len(request_output.outputs[0].token_ids)

    if prompt_ids:
        prompt_count = len(prompt_ids)
        total_count = prompt_count + generated_count
    else:
        # Without prompt ids neither the prompt nor the total can be reported.
        prompt_count = None
        total_count = None

    cached_count = request_output.num_cached_tokens
    details = {"cached_tokens": cached_count} if cached_count else None

    return {
        "prompt_tokens": prompt_count,
        "completion_tokens": generated_count,
        "total_tokens": total_count,
        "prompt_tokens_details": details,
    }
async
def
generate_tokens
(
self
,
prompt
,
sampling_params
,
request_id
,
data_parallel_rank
=
None
):
...
...
@@ -199,6 +222,11 @@ class BaseWorkerHandler(ABC):
out
=
{
"token_ids"
:
output
.
token_ids
[
num_output_tokens_so_far
:]}
if
output
.
finish_reason
:
out
[
"finish_reason"
]
=
output
.
finish_reason
out
[
"completion_usage"
]
=
BaseWorkerHandler
.
_build_completion_usage
(
request_output
=
res
)
if
output
.
stop_reason
:
out
[
"stop_reason"
]
=
output
.
stop_reason
yield
out
...
...
@@ -241,18 +269,24 @@ class DecodeWorkerHandler(BaseWorkerHandler):
# Build sampling params from request
sampling_params
=
build_sampling_params
(
request
,
self
.
default_sampling_params
)
# Extract disaggregated_params from request (set by prefill router in Rust frontend)
disaggregated_params
=
request
.
get
(
"disaggregated_params"
)
if
disaggregated_params
:
# Prefill was performed - use the disaggregated params
if
sampling_params
.
extra_args
is
None
:
sampling_params
.
extra_args
=
{}
sampling_params
.
extra_args
[
"kv_transfer_params"
]
=
disaggregated_params
.
get
(
prefill_result
=
request
.
get
(
"prefill_result"
)
if
prefill_result
and
isinstance
(
prefill_result
,
dict
):
kv_params
=
prefill_result
.
get
(
"disaggregated_params"
,
{}).
get
(
"kv_transfer_params"
)
else
:
kv_params
=
None
if
kv_params
is
not
None
:
if
sampling_params
.
extra_args
is
None
:
sampling_params
.
extra_args
=
{}
sampling_params
.
extra_args
[
"kv_transfer_params"
]
=
kv_params
logger
.
debug
(
f
"Using disaggregated params from prefill for request
{
request_id
}
"
)
prefill_prompt_tokens_details
=
(
prefill_result
.
get
(
"prompt_tokens_details"
)
if
prefill_result
else
None
)
dp_rank
=
request
.
get
(
"dp_rank"
,
None
)
...
...
@@ -261,6 +295,10 @@ class DecodeWorkerHandler(BaseWorkerHandler):
async
for
tok
in
self
.
generate_tokens
(
prompt
,
sampling_params
,
request_id
,
data_parallel_rank
=
dp_rank
):
if
prefill_result
is
not
None
and
"completion_usage"
in
tok
:
tok
[
"completion_usage"
][
"prompt_tokens_details"
]
=
prefill_prompt_tokens_details
yield
tok
except
EngineDeadError
as
e
:
logger
.
error
(
f
"vLLM EngineDeadError:
{
e
}
"
)
...
...
@@ -325,6 +363,9 @@ class PrefillWorkerHandler(BaseWorkerHandler):
if
res
.
kv_transfer_params
else
None
),
"completion_usage"
:
BaseWorkerHandler
.
_build_completion_usage
(
request_output
=
res
),
}
yield
output
...
...
lib/llm/src/backend.rs
View file @
be9d6b2b
...
...
@@ -242,6 +242,7 @@ impl
finish_reason
:
data
.finish_reason
,
//mdcsum: mdcsum.clone(),
index
:
data
.index
,
completion_usage
:
data
.completion_usage
,
})
})
});
...
...
lib/llm/src/kv_router/prefill_router.rs
View file @
be9d6b2b
...
...
@@ -21,6 +21,7 @@ use crate::{
discovery
::
ModelManager
,
kv_router
::{
KvPushRouter
,
KvRouterConfig
,
RouterConfigOverride
},
protocols
::
common
::
llm_backend
::{
LLMEngineOutput
,
PreprocessedRequest
},
protocols
::
common
::
preprocessor
::
PrefillResult
,
};
/// Errors that can occur during prefill routing
...
...
@@ -175,11 +176,11 @@ impl PrefillRouter {
Ok
(())
}
/// Call the prefill router and extract
disaggregated_params
/// Call the prefill router and extract
structured prefill result
async
fn
call_prefill
(
&
self
,
request
:
SingleIn
<
PreprocessedRequest
>
,
)
->
Result
<
serde_json
::
Value
,
PrefillError
>
{
)
->
Result
<
PrefillResult
,
PrefillError
>
{
// Get the prefill router, error if not activated
let
Some
(
prefill_router
)
=
self
.prefill_router
.get
()
else
{
return
Err
(
PrefillError
::
NotActivated
);
...
...
@@ -203,7 +204,22 @@ impl PrefillRouter {
));
};
while
prefill_response
.next
()
.await
.is_some
()
{}
let
mut
prompt_tokens_details
=
first_output
.data
.as_ref
()
.and_then
(|
o
|
o
.completion_usage
.as_ref
())
.and_then
(|
u
|
u
.prompt_tokens_details
.clone
());
while
let
Some
(
next
)
=
prefill_response
.next
()
.await
{
if
let
Some
(
o
)
=
next
.data
.as_ref
()
&&
prompt_tokens_details
.is_none
()
{
prompt_tokens_details
=
o
.completion_usage
.as_ref
()
.and_then
(|
u
|
u
.prompt_tokens_details
.clone
());
}
}
if
let
Some
(
err
)
=
first_output
.err
()
{
return
Err
(
PrefillError
::
PrefillError
(
format!
(
...
...
@@ -223,7 +239,10 @@ impl PrefillRouter {
));
};
Ok
(
disaggregated_params
)
Ok
(
PrefillResult
{
disaggregated_params
,
prompt_tokens_details
,
})
}
}
...
...
@@ -267,12 +286,12 @@ impl
// Attempt prefill and handle results
match
self
.call_prefill
(
prefill_request
)
.await
{
Ok
(
disaggregated_params
)
=>
{
Ok
(
prefill_result
)
=>
{
tracing
::
debug!
(
"Prefill succeeded, using disaggregated params for decode"
);
// Update request with disaggregated_params and router config
let
mut
decode_req
=
req
;
decode_req
.disaggregated_params
=
Some
(
disaggregated_params
);
// Update request with prefill result
decode_req
.prefill_result
=
Some
(
prefill_result
.clone
());
// Restore original max_tokens for decode
decode_req
.stop_conditions.max_tokens
=
original_max_tokens
;
...
...
lib/llm/src/migration.rs
View file @
be9d6b2b
...
...
@@ -219,6 +219,7 @@ mod tests {
index
:
None
,
disaggregated_params
:
None
,
extra_args
:
None
,
completion_usage
:
None
,
})
}
...
...
lib/llm/src/mocker/engine.rs
View file @
be9d6b2b
...
...
@@ -308,6 +308,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
None
},
extra_args
:
None
,
completion_usage
:
None
,
};
if
signal
.completed
&&
token_count
<
max_output_tokens
{
...
...
lib/llm/src/protocols/common/llm_backend.rs
View file @
be9d6b2b
...
...
@@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize};
pub
use
super
::
FinishReason
;
pub
use
super
::
preprocessor
::
PreprocessedRequest
;
use
crate
::
protocols
::
TokenIdType
;
use
dynamo_async_openai
::
types
::
CompletionUsage
;
use
dynamo_runtime
::
protocols
::
maybe_error
::
MaybeError
;
pub
type
TokenType
=
Option
<
String
>
;
...
...
@@ -48,6 +49,10 @@ pub struct BackendOutput {
// Index field for batch requests to match OpenAI format
pub
index
:
Option
<
u32
>
,
// Token usage information
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
completion_usage
:
Option
<
CompletionUsage
>
,
}
/// The LLM engine and backend will manage their own state, specifically translating how a
...
...
@@ -92,6 +97,10 @@ pub struct LLMEngineOutput {
/// Additional arguments for extensibility
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
extra_args
:
Option
<
serde_json
::
Value
>
,
// Token usage information
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
completion_usage
:
Option
<
CompletionUsage
>
,
}
impl
LLMEngineOutput
{
...
...
@@ -107,6 +116,7 @@ impl LLMEngineOutput {
index
:
None
,
disaggregated_params
:
None
,
extra_args
:
None
,
completion_usage
:
None
,
}
}
...
...
@@ -122,6 +132,7 @@ impl LLMEngineOutput {
index
:
None
,
disaggregated_params
:
None
,
extra_args
:
None
,
completion_usage
:
None
,
}
}
...
...
@@ -137,6 +148,7 @@ impl LLMEngineOutput {
index
:
None
,
disaggregated_params
:
None
,
extra_args
:
None
,
completion_usage
:
None
,
}
}
...
...
@@ -152,6 +164,7 @@ impl LLMEngineOutput {
index
:
None
,
disaggregated_params
:
None
,
extra_args
:
None
,
completion_usage
:
None
,
}
}
}
...
...
lib/llm/src/protocols/common/preprocessor.rs
View file @
be9d6b2b
...
...
@@ -8,6 +8,15 @@ use super::{OutputOptions, SamplingOptions, StopConditions};
use
crate
::
kv_router
::
RouterConfigOverride
;
use
crate
::
protocols
::
TokenIdType
;
/// Structured result of a prefill call that is attached to a request.
///
/// Both fields are serialized with serde; `prompt_tokens_details` is left out
/// of the serialized form when it is `None` and defaults to `None` when the
/// field is missing on deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefillResult {
    /// Disaggregated execution parameters, kept as an untyped JSON value.
    pub disaggregated_params: serde_json::Value,

    /// Prompt token details produced during prefill, if any were reported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<dynamo_async_openai::types::PromptTokensDetails>,
}
#[derive(Serialize,
Deserialize,
Debug,
Clone)]
pub
enum
MultimodalData
{
Url
(
url
::
Url
),
...
...
@@ -69,10 +78,10 @@ pub struct PreprocessedRequest {
#[builder(default)]
pub
router_config_override
:
Option
<
RouterConfigOverride
>
,
///
Disaggregated execution parameters (for prefill/decode separation)
///
Structured prefill result
#[builder(default)]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
disaggregated_params
:
Option
<
serde_json
::
Value
>
,
pub
prefill_result
:
Option
<
PrefillResult
>
,
/// Data parallel rank for the request (used with data parallelism)
#[builder(default)]
...
...
lib/llm/src/protocols/openai/chat_completions/delta.rs
View file @
be9d6b2b
...
...
@@ -316,6 +316,16 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
.expect
(
"token_ids length exceeds u32::MAX"
);
self
.usage.completion_tokens
+=
token_length
;
// If backend provides completion_usage with prompt token details,
// propagate the entire details struct to usage tracking
if
let
Some
(
prompt_details
)
=
delta
.completion_usage
.as_ref
()
.and_then
(|
usage
|
usage
.prompt_tokens_details
.as_ref
())
{
self
.usage.prompt_tokens_details
=
Some
(
prompt_details
.clone
());
}
}
let
logprobs
=
self
.create_logprobs
(
...
...
lib/llm/src/protocols/openai/completions/delta.rs
View file @
be9d6b2b
...
...
@@ -238,6 +238,16 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
.expect
(
"token_ids length exceeds u32::MAX"
);
self
.usage.completion_tokens
+=
token_length
;
// If backend provides completion_usage with prompt token details,
// propagate the entire details struct to usage tracking
if
let
Some
(
prompt_details
)
=
delta
.completion_usage
.as_ref
()
.and_then
(|
usage
|
usage
.prompt_tokens_details
.as_ref
())
{
self
.usage.prompt_tokens_details
=
Some
(
prompt_details
.clone
());
}
}
let
logprobs
=
self
.create_logprobs
(
...
...
lib/llm/tests/test_streaming_usage.rs
View file @
be9d6b2b
...
...
@@ -7,12 +7,17 @@ use dynamo_async_openai::types::{
ChatCompletionRequestUserMessageContent
,
ChatCompletionStreamOptions
,
CreateChatCompletionRequest
,
};
use
dynamo_async_openai
::
types
::{
CompletionUsage
as
AoaiCompletionUsage
,
CreateCompletionRequestArgs
,
Prompt
,
PromptTokensDetails
,
};
use
dynamo_llm
::
preprocessor
::
OpenAIPreprocessor
;
use
dynamo_llm
::
protocols
::
common
::
llm_backend
::{
BackendOutput
,
FinishReason
};
use
dynamo_llm
::
protocols
::
openai
::
ParsingOptions
;
use
dynamo_llm
::
protocols
::
openai
::
chat_completions
::{
NvCreateChatCompletionRequest
,
aggregator
::
ChatCompletionAggregator
,
};
use
dynamo_llm
::
protocols
::
openai
::
completions
::
NvCreateCompletionRequest
;
use
dynamo_runtime
::
engine
::{
AsyncEngineContext
,
AsyncEngineStream
};
use
dynamo_runtime
::
protocols
::
annotated
::
Annotated
;
use
futures
::
StreamExt
;
...
...
@@ -82,8 +87,17 @@ impl AsyncEngineContext for MockContext {
fn
create_mock_backend_stream
(
ctx
:
Arc
<
dyn
AsyncEngineContext
>
,
)
->
Pin
<
Box
<
dyn
AsyncEngineStream
<
Annotated
<
BackendOutput
>>>>
{
let
outputs
=
vec!
[
// First chunk with "Hello"
let
outputs
=
build_backend_outputs_with_cached_tokens
(
None
);
let
stream
=
stream
::
iter
(
outputs
.into_iter
()
.map
(
Annotated
::
from_data
));
use
dynamo_runtime
::
engine
::
ResponseStream
;
ResponseStream
::
new
(
Box
::
pin
(
stream
),
ctx
)
}
/// Build three backend outputs: "Hello", " world", "!" with optional cached_tokens on the final chunk
fn
build_backend_outputs_with_cached_tokens
(
cached_tokens
:
Option
<
u32
>
)
->
Vec
<
BackendOutput
>
{
vec!
[
BackendOutput
{
token_ids
:
vec!
[
15339
],
tokens
:
vec!
[
Some
(
"Hello"
.to_string
())],
...
...
@@ -93,8 +107,8 @@ fn create_mock_backend_stream(
top_logprobs
:
None
,
finish_reason
:
None
,
index
:
Some
(
0
),
completion_usage
:
None
,
},
// Second chunk with " world"
BackendOutput
{
token_ids
:
vec!
[
1917
],
tokens
:
vec!
[
Some
(
" world"
.to_string
())],
...
...
@@ -104,8 +118,8 @@ fn create_mock_backend_stream(
top_logprobs
:
None
,
finish_reason
:
None
,
index
:
Some
(
0
),
completion_usage
:
None
,
},
// Third chunk with "!" and finish_reason
BackendOutput
{
token_ids
:
vec!
[
0
],
tokens
:
vec!
[
Some
(
"!"
.to_string
())],
...
...
@@ -115,11 +129,27 @@ fn create_mock_backend_stream(
top_logprobs
:
None
,
finish_reason
:
Some
(
FinishReason
::
Stop
),
index
:
Some
(
0
),
completion_usage
:
cached_tokens
.map
(|
ct
|
AoaiCompletionUsage
{
prompt_tokens
:
0
,
completion_tokens
:
0
,
total_tokens
:
0
,
prompt_tokens_details
:
Some
(
PromptTokensDetails
{
audio_tokens
:
None
,
cached_tokens
:
Some
(
ct
),
}),
completion_tokens_details
:
None
,
}),
},
];
]
}
/// Wrap the standard backend outputs in a response stream bound to `ctx`.
///
/// `cached_tokens` is forwarded unchanged to
/// `build_backend_outputs_with_cached_tokens`; each resulting output is
/// wrapped with `Annotated::from_data` before being streamed.
fn create_backend_stream_with_cached_tokens(
    ctx: Arc<dyn AsyncEngineContext>,
    cached_tokens: Option<u32>,
) -> Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>> {
    let annotated_chunks = build_backend_outputs_with_cached_tokens(cached_tokens)
        .into_iter()
        .map(Annotated::from_data);

    dynamo_runtime::engine::ResponseStream::new(Box::pin(stream::iter(annotated_chunks)), ctx)
}
...
...
@@ -308,6 +338,31 @@ async fn test_streaming_with_usage_false() {
}
}
/// Build a completions request for model `"test-model"` with prompt `"Hello"`.
///
/// `stream` sets the request's stream flag. When `include_usage` is `Some`,
/// `stream_options` is set with that `include_usage` value; when it is `None`,
/// `stream_options` is not set on the builder at all.
fn create_cmpl_request(include_usage: Option<bool>, stream: bool) -> NvCreateCompletionRequest {
    let mut args = CreateCompletionRequestArgs::default();
    args.model("test-model")
        .prompt(Prompt::String("Hello".to_string()))
        .stream(stream);

    if let Some(flag) = include_usage {
        args.stream_options(dynamo_async_openai::types::ChatCompletionStreamOptions {
            include_usage: flag,
        });
    }

    NvCreateCompletionRequest {
        inner: args.build().unwrap(),
        common: Default::default(),
        nvext: None,
        metadata: None,
        unsupported_fields: Default::default(),
    }
}
/// Helper to create a non-streaming chat completion request
fn
create_nonstreaming_chat_request
()
->
NvCreateChatCompletionRequest
{
let
messages
=
vec!
[
ChatCompletionRequestMessage
::
User
(
...
...
@@ -404,3 +459,195 @@ async fn test_nonstreaming_has_usage_field() {
"Total tokens should equal prompt_tokens + completion_tokens"
);
}
/// Completions with `stream=true` and `include_usage=true` while the backend
/// sends no `completion_usage` in any chunk.
///
/// Expects three content chunks without usage followed by one usage-only chunk
/// whose `completion_tokens` is 3 and whose `prompt_tokens_details` is `None`.
#[tokio::test]
async fn test_cmpl_streaming_with_usage_true_no_backend_usage() {
    let request = create_cmpl_request(Some(true), true);

    let request_id = "cmpl-usage-none-1".to_string();
    let response_generator = Box::new(request.response_generator(request_id));

    // Mock backend stream (no completion_usage in any chunk)
    let ctx = Arc::new(MockContext::new());
    let backend_stream = create_mock_backend_stream(ctx.clone());

    // Transform
    let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
        backend_stream,
        response_generator,
        ctx.clone(),
    );

    let chunks: Vec<_> = transformed_stream.collect().await;

    // Expect 3 content chunks + 1 usage-only chunk
    assert_eq!(chunks.len(), 4, "Should have 3 content + 1 usage chunk");

    // First 3 chunks: usage must be None.
    // A chunk without a payload is a failure; skipping it silently would let
    // the assertions below pass without ever running.
    for (i, chunk) in chunks.iter().take(3).enumerate() {
        let Some(resp) = &chunk.data else {
            panic!("Content chunk {} should carry response data", i);
        };
        assert!(
            resp.inner.usage.is_none(),
            "Content chunk {} should have usage: None",
            i
        );
        assert!(
            !resp.inner.choices.is_empty(),
            "Content chunk {} should have choices",
            i
        );
    }

    // Final usage chunk: usage present with counts; prompt_tokens_details None (no backend usage)
    let Some(final_resp) = &chunks[3].data else {
        panic!("Final chunk should be present");
    };
    assert!(
        final_resp.inner.choices.is_empty(),
        "Usage-only chunk must have empty choices"
    );
    let usage = final_resp
        .inner
        .usage
        .as_ref()
        .expect("Usage must be present");
    assert_eq!(
        usage.completion_tokens, 3,
        "Aggregated completion tokens should be 3"
    );
    assert!(
        usage.prompt_tokens_details.is_none(),
        "prompt_tokens_details should be None when backend does not send usage"
    );
}
/// Completions with `include_usage=true`: `cached_tokens` carried by the
/// backend's final chunk must show up in the final usage chunk.
#[tokio::test]
async fn test_cmpl_streaming_with_cached_tokens_propagation() {
    let request = create_cmpl_request(Some(true), true);
    let mut response_generator = Box::new(
        request.response_generator("cmpl-usage-cached-1".to_string()),
    );

    // The last backend chunk carries completion_usage with cached_tokens = 7.
    let ctx = Arc::new(MockContext::new());
    let backend_stream = create_backend_stream_with_cached_tokens(ctx.clone(), Some(7));

    // Align ISL so total usage gets computed correctly
    response_generator.update_isl(0);

    let chunks: Vec<_> = OpenAIPreprocessor::transform_postprocessor_stream(
        backend_stream,
        response_generator,
        ctx.clone(),
    )
    .collect()
    .await;

    // Expect 4 chunks total
    assert_eq!(chunks.len(), 4, "Should have 3 content + 1 usage chunk");

    // The trailing usage chunk must expose the propagated cached_tokens.
    match &chunks[3].data {
        Some(final_resp) => {
            let usage = final_resp
                .inner
                .usage
                .as_ref()
                .expect("Usage must be present on final chunk");
            let cached = usage
                .prompt_tokens_details
                .as_ref()
                .and_then(|details| details.cached_tokens);
            assert_eq!(
                cached,
                Some(7),
                "cached_tokens must propagate to final usage chunk"
            );
        }
        None => panic!("Final chunk should be present"),
    }
}
/// Chat completions with `include_usage=true`: `cached_tokens` carried by the
/// backend's final chunk must show up in the final usage chunk.
#[tokio::test]
async fn test_chat_streaming_with_cached_tokens_propagation() {
    let request = create_chat_request(Some(true));
    let mut response_generator = Box::new(
        request.response_generator("chat-usage-cached-1".to_string()),
    );

    // The last backend chunk carries completion_usage with cached_tokens = 5.
    let ctx = Arc::new(MockContext::new());
    let backend_stream = create_backend_stream_with_cached_tokens(ctx.clone(), Some(5));

    // Align ISL if needed
    response_generator.update_isl(0);

    let chunks: Vec<_> = OpenAIPreprocessor::transform_postprocessor_stream(
        backend_stream,
        response_generator,
        ctx.clone(),
    )
    .collect()
    .await;

    assert_eq!(chunks.len(), 4, "Should have 3 content + 1 usage chunk");

    match &chunks[3].data {
        Some(final_resp) => {
            let usage = final_resp.usage.as_ref().expect("Usage must be present");
            let cached = usage
                .prompt_tokens_details
                .as_ref()
                .and_then(|details| details.cached_tokens);
            assert_eq!(
                cached,
                Some(5),
                "cached_tokens must propagate for chat completions"
            );
        }
        None => panic!("Final chunk should be present"),
    }
}
/// Non-streaming completions: the aggregated response must carry usage with
/// the summed `completion_tokens` and the backend-provided `cached_tokens`.
#[tokio::test]
async fn test_cmpl_nonstreaming_has_usage_and_cached_tokens() {
    let mut request = create_cmpl_request(None, false);

    // Simulate preprocessor behavior for non-streaming
    let original_stream_flag = request.inner.stream.unwrap_or(false);
    request.enable_usage_for_nonstreaming(original_stream_flag);

    let response_generator = Box::new(
        request.response_generator("cmpl-nonstream-usage".to_string()),
    );

    // Three backend chunks; the last one carries completion_usage with cached_tokens = 9.
    let ctx = Arc::new(MockContext::new());
    let backend_stream = create_backend_stream_with_cached_tokens(ctx.clone(), Some(9));

    // Transform to OpenAI completion stream
    let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
        backend_stream,
        response_generator,
        ctx.clone(),
    );

    // Collapse the stream into one non-streaming response.
    let result =
        dynamo_llm::protocols::openai::completions::NvCreateCompletionResponse::from_annotated_stream(
            transformed_stream,
            ParsingOptions::default(),
        )
        .await;
    assert!(result.is_ok(), "Aggregation should succeed");

    let usage = result
        .unwrap()
        .inner
        .usage
        .expect("usage must be present for non-streaming");
    assert_eq!(
        usage.completion_tokens, 3,
        "completion_tokens must aggregate"
    );

    let cached = usage
        .prompt_tokens_details
        .and_then(|details| details.cached_tokens);
    assert_eq!(
        cached,
        Some(9),
        "cached_tokens must propagate to non-streaming response"
    );
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment