Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
be9d6b2b
Unverified
Commit
be9d6b2b
authored
Nov 19, 2025
by
Vladislav Nosivskoy
Committed by
GitHub
Nov 18, 2025
Browse files
feat: support prompt_tokens_details in usage (#4239)
Signed-off-by:
Vladislav Nosivskoy
<
vladnosiv@gmail.com
>
parent
0f4d7634
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
432 additions
and
24 deletions
+432
-24
components/src/dynamo/router/__main__.py
components/src/dynamo/router/__main__.py
+1
-0
components/src/dynamo/sglang/request_handlers/llm/decode_handler.py
.../src/dynamo/sglang/request_handlers/llm/decode_handler.py
+13
-0
components/src/dynamo/trtllm/main.py
components/src/dynamo/trtllm/main.py
+5
-0
components/src/dynamo/trtllm/request_handlers/handler_base.py
...onents/src/dynamo/trtllm/request_handlers/handler_base.py
+39
-2
components/src/dynamo/vllm/handlers.py
components/src/dynamo/vllm/handlers.py
+48
-7
lib/llm/src/backend.rs
lib/llm/src/backend.rs
+1
-0
lib/llm/src/kv_router/prefill_router.rs
lib/llm/src/kv_router/prefill_router.rs
+26
-7
lib/llm/src/migration.rs
lib/llm/src/migration.rs
+1
-0
lib/llm/src/mocker/engine.rs
lib/llm/src/mocker/engine.rs
+1
-0
lib/llm/src/protocols/common/llm_backend.rs
lib/llm/src/protocols/common/llm_backend.rs
+13
-0
lib/llm/src/protocols/common/preprocessor.rs
lib/llm/src/protocols/common/preprocessor.rs
+11
-2
lib/llm/src/protocols/openai/chat_completions/delta.rs
lib/llm/src/protocols/openai/chat_completions/delta.rs
+10
-0
lib/llm/src/protocols/openai/completions/delta.rs
lib/llm/src/protocols/openai/completions/delta.rs
+10
-0
lib/llm/tests/test_streaming_usage.rs
lib/llm/tests/test_streaming_usage.rs
+253
-6
No files found.
components/src/dynamo/router/__main__.py
View file @
be9d6b2b
...
@@ -120,6 +120,7 @@ class StandaloneRouterHandler:
...
@@ -120,6 +120,7 @@ class StandaloneRouterHandler:
"index"
:
worker_output
.
get
(
"index"
),
"index"
:
worker_output
.
get
(
"index"
),
"disaggregated_params"
:
worker_output
.
get
(
"disaggregated_params"
),
"disaggregated_params"
:
worker_output
.
get
(
"disaggregated_params"
),
"extra_args"
:
worker_output
.
get
(
"extra_args"
),
"extra_args"
:
worker_output
.
get
(
"extra_args"
),
"completion_usage"
:
worker_output
.
get
(
"completion_usage"
),
}
}
yield
llm_engine_output
yield
llm_engine_output
...
...
components/src/dynamo/sglang/request_handlers/llm/decode_handler.py
View file @
be9d6b2b
...
@@ -229,6 +229,19 @@ class DecodeWorkerHandler(BaseWorkerHandler):
...
@@ -229,6 +229,19 @@ class DecodeWorkerHandler(BaseWorkerHandler):
next_total_toks
=
len
(
output_ids
)
next_total_toks
=
len
(
output_ids
)
out
[
"token_ids"
]
=
output_ids
[
num_output_tokens_so_far
:]
out
[
"token_ids"
]
=
output_ids
[
num_output_tokens_so_far
:]
num_output_tokens_so_far
=
next_total_toks
num_output_tokens_so_far
=
next_total_toks
if
finish_reason
:
input_tokens
=
res
[
"meta_info"
][
"prompt_tokens"
]
completion_tokens
=
res
[
"meta_info"
][
"completion_tokens"
]
cached_tokens
=
res
[
"meta_info"
][
"cached_tokens"
]
prefill_prompt_tokens_details
=
None
if
cached_tokens
is
not
None
and
cached_tokens
>
0
:
prefill_prompt_tokens_details
=
{
"cached_tokens"
:
cached_tokens
}
out
[
"completion_usage"
]
=
{
"prompt_tokens"
:
input_tokens
,
"completion_tokens"
:
completion_tokens
,
"total_tokens"
:
input_tokens
+
completion_tokens
,
"prompt_tokens_details"
:
prefill_prompt_tokens_details
,
}
if
not
context
.
is_stopped
():
if
not
context
.
is_stopped
():
yield
out
yield
out
...
...
components/src/dynamo/trtllm/main.py
View file @
be9d6b2b
...
@@ -242,6 +242,10 @@ async def init(runtime: DistributedRuntime, config: Config):
...
@@ -242,6 +242,10 @@ async def init(runtime: DistributedRuntime, config: Config):
default_sampling_params
=
SamplingParams
()
default_sampling_params
=
SamplingParams
()
default_sampling_params
.
_setup
(
tokenizer
)
default_sampling_params
.
_setup
(
tokenizer
)
default_sampling_params
.
stop
=
None
default_sampling_params
.
stop
=
None
# Enable perf metrics so prompt_tokens_details can be returned
if
hasattr
(
default_sampling_params
,
"return_perf_metrics"
):
default_sampling_params
.
return_perf_metrics
=
True
model_input
=
ModelInput
.
Tokens
model_input
=
ModelInput
.
Tokens
# Set model type based on disaggregation mode for unified frontend support
# Set model type based on disaggregation mode for unified frontend support
...
@@ -356,6 +360,7 @@ async def init(runtime: DistributedRuntime, config: Config):
...
@@ -356,6 +360,7 @@ async def init(runtime: DistributedRuntime, config: Config):
connector
=
connector
,
connector
=
connector
,
runtime
=
runtime
,
# Pass runtime for graceful shutdown
runtime
=
runtime
,
# Pass runtime for graceful shutdown
metrics_collector
=
metrics_collector
,
metrics_collector
=
metrics_collector
,
kv_block_size
=
config
.
kv_block_size
,
)
)
# Register the model with runtime config
# Register the model with runtime config
...
...
components/src/dynamo/trtllm/request_handlers/handler_base.py
View file @
be9d6b2b
...
@@ -72,6 +72,7 @@ class RequestHandlerConfig:
...
@@ -72,6 +72,7 @@ class RequestHandlerConfig:
DistributedRuntime
DistributedRuntime
]
=
None
# DistributedRuntime reference for graceful shutdown
]
=
None
# DistributedRuntime reference for graceful shutdown
metrics_collector
:
Optional
[
Any
]
=
None
# TensorRT-LLM MetricsCollector
metrics_collector
:
Optional
[
Any
]
=
None
# TensorRT-LLM MetricsCollector
kv_block_size
:
int
=
32
class
HandlerBase
:
class
HandlerBase
:
...
@@ -92,6 +93,7 @@ class HandlerBase:
...
@@ -92,6 +93,7 @@ class HandlerBase:
self
.
connector
=
config
.
connector
self
.
connector
=
config
.
connector
# Store runtime reference for graceful shutdown
# Store runtime reference for graceful shutdown
self
.
runtime
=
config
.
runtime
self
.
runtime
=
config
.
runtime
self
.
kv_block_size
:
int
=
config
.
kv_block_size
def
check_error
(
self
,
result
:
dict
):
def
check_error
(
self
,
result
:
dict
):
"""
"""
...
@@ -208,11 +210,13 @@ class HandlerBase:
...
@@ -208,11 +210,13 @@ class HandlerBase:
request
[
"stop_conditions"
][
"max_tokens"
]
=
1
request
[
"stop_conditions"
][
"max_tokens"
]
=
1
disaggregated_params
=
LlmDisaggregatedParams
(
request_type
=
"context_only"
)
disaggregated_params
=
LlmDisaggregatedParams
(
request_type
=
"context_only"
)
if
"
disaggregated_params
"
in
request
:
if
"
prefill_result
"
in
request
:
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
raise
ValueError
(
"Cannot provide disaggregated_params in prefill mode"
)
raise
ValueError
(
"Cannot provide disaggregated_params in prefill mode"
)
disaggregated_params
=
DisaggregatedParamsCodec
.
decode
(
disaggregated_params
=
DisaggregatedParamsCodec
.
decode
(
DisaggregatedParams
(
**
request
[
"disaggregated_params"
])
DisaggregatedParams
(
**
request
[
"prefill_result"
].
get
(
"disaggregated_params"
)
)
)
)
disaggregated_params
.
request_type
=
"generation_only"
disaggregated_params
.
request_type
=
"generation_only"
...
@@ -258,6 +262,11 @@ class HandlerBase:
...
@@ -258,6 +262,11 @@ class HandlerBase:
adapters
=
create_trtllm_adapters
(
processors
)
adapters
=
create_trtllm_adapters
(
processors
)
sampling_params
.
logits_processor
=
adapters
sampling_params
.
logits_processor
=
adapters
prefill_result
=
request
.
get
(
"prefill_result"
)
prefill_prompt_tokens_details
=
(
prefill_result
.
get
(
"prompt_tokens_details"
)
if
prefill_result
else
None
)
try
:
try
:
# NEW: Updated engine call to include multimodal data
# NEW: Updated engine call to include multimodal data
generation_result
=
self
.
engine
.
llm
.
generate_async
(
generation_result
=
self
.
engine
.
llm
.
generate_async
(
...
@@ -298,6 +307,34 @@ class HandlerBase:
...
@@ -298,6 +307,34 @@ class HandlerBase:
DisaggregatedParamsCodec
.
encode
(
output
.
disaggregated_params
)
DisaggregatedParamsCodec
.
encode
(
output
.
disaggregated_params
)
)
)
if
out
.
get
(
"finish_reason"
):
num_input_tokens
=
len
(
request
.
get
(
"token_ids"
,
[]))
prompt_tokens_details
=
None
if
prefill_prompt_tokens_details
:
prompt_tokens_details
=
prefill_prompt_tokens_details
else
:
if
output
.
request_perf_metrics
is
not
None
:
kv_cache_metrics
=
(
output
.
request_perf_metrics
.
kv_cache_metrics
)
cached_tokens
=
min
(
num_input_tokens
,
kv_cache_metrics
.
num_reused_blocks
*
self
.
kv_block_size
,
)
if
cached_tokens
>
0
:
prompt_tokens_details
=
{
"cached_tokens"
:
int
(
cached_tokens
),
}
out
[
"completion_usage"
]
=
{
"prompt_tokens"
:
int
(
num_input_tokens
),
"completion_tokens"
:
int
(
next_total_toks
),
"total_tokens"
:
int
(
num_input_tokens
+
next_total_toks
),
"prompt_tokens_details"
:
prompt_tokens_details
,
}
if
res
.
finished
and
not
out
.
get
(
"finish_reason"
):
if
res
.
finished
and
not
out
.
get
(
"finish_reason"
):
out
[
"finish_reason"
]
=
"unknown"
out
[
"finish_reason"
]
=
"unknown"
logging
.
warning
(
logging
.
warning
(
...
...
components/src/dynamo/vllm/handlers.py
View file @
be9d6b2b
...
@@ -10,6 +10,7 @@ from contextlib import asynccontextmanager
...
@@ -10,6 +10,7 @@ from contextlib import asynccontextmanager
from
typing
import
Any
,
AsyncGenerator
,
Dict
,
Final
from
typing
import
Any
,
AsyncGenerator
,
Dict
,
Final
from
vllm.inputs
import
TokensPrompt
from
vllm.inputs
import
TokensPrompt
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.engine.exceptions
import
EngineDeadError
from
vllm.v1.engine.exceptions
import
EngineDeadError
...
@@ -174,6 +175,28 @@ class BaseWorkerHandler(ABC):
...
@@ -174,6 +175,28 @@ class BaseWorkerHandler(ABC):
return
vllm_mm_data
if
vllm_mm_data
else
None
return
vllm_mm_data
if
vllm_mm_data
else
None
@
staticmethod
def
_build_completion_usage
(
request_output
:
RequestOutput
)
->
Dict
[
str
,
Any
]:
return
{
"prompt_tokens"
:
(
len
(
request_output
.
prompt_token_ids
)
if
request_output
.
prompt_token_ids
else
None
),
"completion_tokens"
:
len
(
request_output
.
outputs
[
0
].
token_ids
),
"total_tokens"
:
(
len
(
request_output
.
prompt_token_ids
)
+
len
(
request_output
.
outputs
[
0
].
token_ids
)
if
request_output
.
prompt_token_ids
else
None
),
"prompt_tokens_details"
:
(
{
"cached_tokens"
:
request_output
.
num_cached_tokens
}
if
request_output
.
num_cached_tokens
else
None
),
}
async
def
generate_tokens
(
async
def
generate_tokens
(
self
,
prompt
,
sampling_params
,
request_id
,
data_parallel_rank
=
None
self
,
prompt
,
sampling_params
,
request_id
,
data_parallel_rank
=
None
):
):
...
@@ -199,6 +222,11 @@ class BaseWorkerHandler(ABC):
...
@@ -199,6 +222,11 @@ class BaseWorkerHandler(ABC):
out
=
{
"token_ids"
:
output
.
token_ids
[
num_output_tokens_so_far
:]}
out
=
{
"token_ids"
:
output
.
token_ids
[
num_output_tokens_so_far
:]}
if
output
.
finish_reason
:
if
output
.
finish_reason
:
out
[
"finish_reason"
]
=
output
.
finish_reason
out
[
"finish_reason"
]
=
output
.
finish_reason
out
[
"completion_usage"
]
=
BaseWorkerHandler
.
_build_completion_usage
(
request_output
=
res
)
if
output
.
stop_reason
:
if
output
.
stop_reason
:
out
[
"stop_reason"
]
=
output
.
stop_reason
out
[
"stop_reason"
]
=
output
.
stop_reason
yield
out
yield
out
...
@@ -241,18 +269,24 @@ class DecodeWorkerHandler(BaseWorkerHandler):
...
@@ -241,18 +269,24 @@ class DecodeWorkerHandler(BaseWorkerHandler):
# Build sampling params from request
# Build sampling params from request
sampling_params
=
build_sampling_params
(
request
,
self
.
default_sampling_params
)
sampling_params
=
build_sampling_params
(
request
,
self
.
default_sampling_params
)
# Extract disaggregated_params from request (set by prefill router in Rust frontend)
prefill_result
=
request
.
get
(
"prefill_result"
)
disaggregated_params
=
request
.
get
(
"disaggregated_params"
)
if
prefill_result
and
isinstance
(
prefill_result
,
dict
):
if
disaggregated_params
:
kv_params
=
prefill_result
.
get
(
"disaggregated_params"
,
{}).
get
(
# Prefill was performed - use the disaggregated params
if
sampling_params
.
extra_args
is
None
:
sampling_params
.
extra_args
=
{}
sampling_params
.
extra_args
[
"kv_transfer_params"
]
=
disaggregated_params
.
get
(
"kv_transfer_params"
"kv_transfer_params"
)
)
else
:
kv_params
=
None
if
kv_params
is
not
None
:
if
sampling_params
.
extra_args
is
None
:
sampling_params
.
extra_args
=
{}
sampling_params
.
extra_args
[
"kv_transfer_params"
]
=
kv_params
logger
.
debug
(
logger
.
debug
(
f
"Using disaggregated params from prefill for request
{
request_id
}
"
f
"Using disaggregated params from prefill for request
{
request_id
}
"
)
)
prefill_prompt_tokens_details
=
(
prefill_result
.
get
(
"prompt_tokens_details"
)
if
prefill_result
else
None
)
dp_rank
=
request
.
get
(
"dp_rank"
,
None
)
dp_rank
=
request
.
get
(
"dp_rank"
,
None
)
...
@@ -261,6 +295,10 @@ class DecodeWorkerHandler(BaseWorkerHandler):
...
@@ -261,6 +295,10 @@ class DecodeWorkerHandler(BaseWorkerHandler):
async
for
tok
in
self
.
generate_tokens
(
async
for
tok
in
self
.
generate_tokens
(
prompt
,
sampling_params
,
request_id
,
data_parallel_rank
=
dp_rank
prompt
,
sampling_params
,
request_id
,
data_parallel_rank
=
dp_rank
):
):
if
prefill_result
is
not
None
and
"completion_usage"
in
tok
:
tok
[
"completion_usage"
][
"prompt_tokens_details"
]
=
prefill_prompt_tokens_details
yield
tok
yield
tok
except
EngineDeadError
as
e
:
except
EngineDeadError
as
e
:
logger
.
error
(
f
"vLLM EngineDeadError:
{
e
}
"
)
logger
.
error
(
f
"vLLM EngineDeadError:
{
e
}
"
)
...
@@ -325,6 +363,9 @@ class PrefillWorkerHandler(BaseWorkerHandler):
...
@@ -325,6 +363,9 @@ class PrefillWorkerHandler(BaseWorkerHandler):
if
res
.
kv_transfer_params
if
res
.
kv_transfer_params
else
None
else
None
),
),
"completion_usage"
:
BaseWorkerHandler
.
_build_completion_usage
(
request_output
=
res
),
}
}
yield
output
yield
output
...
...
lib/llm/src/backend.rs
View file @
be9d6b2b
...
@@ -242,6 +242,7 @@ impl
...
@@ -242,6 +242,7 @@ impl
finish_reason
:
data
.finish_reason
,
finish_reason
:
data
.finish_reason
,
//mdcsum: mdcsum.clone(),
//mdcsum: mdcsum.clone(),
index
:
data
.index
,
index
:
data
.index
,
completion_usage
:
data
.completion_usage
,
})
})
})
})
});
});
...
...
lib/llm/src/kv_router/prefill_router.rs
View file @
be9d6b2b
...
@@ -21,6 +21,7 @@ use crate::{
...
@@ -21,6 +21,7 @@ use crate::{
discovery
::
ModelManager
,
discovery
::
ModelManager
,
kv_router
::{
KvPushRouter
,
KvRouterConfig
,
RouterConfigOverride
},
kv_router
::{
KvPushRouter
,
KvRouterConfig
,
RouterConfigOverride
},
protocols
::
common
::
llm_backend
::{
LLMEngineOutput
,
PreprocessedRequest
},
protocols
::
common
::
llm_backend
::{
LLMEngineOutput
,
PreprocessedRequest
},
protocols
::
common
::
preprocessor
::
PrefillResult
,
};
};
/// Errors that can occur during prefill routing
/// Errors that can occur during prefill routing
...
@@ -175,11 +176,11 @@ impl PrefillRouter {
...
@@ -175,11 +176,11 @@ impl PrefillRouter {
Ok
(())
Ok
(())
}
}
/// Call the prefill router and extract
disaggregated_params
/// Call the prefill router and extract
structured prefill result
async
fn
call_prefill
(
async
fn
call_prefill
(
&
self
,
&
self
,
request
:
SingleIn
<
PreprocessedRequest
>
,
request
:
SingleIn
<
PreprocessedRequest
>
,
)
->
Result
<
serde_json
::
Value
,
PrefillError
>
{
)
->
Result
<
PrefillResult
,
PrefillError
>
{
// Get the prefill router, error if not activated
// Get the prefill router, error if not activated
let
Some
(
prefill_router
)
=
self
.prefill_router
.get
()
else
{
let
Some
(
prefill_router
)
=
self
.prefill_router
.get
()
else
{
return
Err
(
PrefillError
::
NotActivated
);
return
Err
(
PrefillError
::
NotActivated
);
...
@@ -203,7 +204,22 @@ impl PrefillRouter {
...
@@ -203,7 +204,22 @@ impl PrefillRouter {
));
));
};
};
while
prefill_response
.next
()
.await
.is_some
()
{}
let
mut
prompt_tokens_details
=
first_output
.data
.as_ref
()
.and_then
(|
o
|
o
.completion_usage
.as_ref
())
.and_then
(|
u
|
u
.prompt_tokens_details
.clone
());
while
let
Some
(
next
)
=
prefill_response
.next
()
.await
{
if
let
Some
(
o
)
=
next
.data
.as_ref
()
&&
prompt_tokens_details
.is_none
()
{
prompt_tokens_details
=
o
.completion_usage
.as_ref
()
.and_then
(|
u
|
u
.prompt_tokens_details
.clone
());
}
}
if
let
Some
(
err
)
=
first_output
.err
()
{
if
let
Some
(
err
)
=
first_output
.err
()
{
return
Err
(
PrefillError
::
PrefillError
(
format!
(
return
Err
(
PrefillError
::
PrefillError
(
format!
(
...
@@ -223,7 +239,10 @@ impl PrefillRouter {
...
@@ -223,7 +239,10 @@ impl PrefillRouter {
));
));
};
};
Ok
(
disaggregated_params
)
Ok
(
PrefillResult
{
disaggregated_params
,
prompt_tokens_details
,
})
}
}
}
}
...
@@ -267,12 +286,12 @@ impl
...
@@ -267,12 +286,12 @@ impl
// Attempt prefill and handle results
// Attempt prefill and handle results
match
self
.call_prefill
(
prefill_request
)
.await
{
match
self
.call_prefill
(
prefill_request
)
.await
{
Ok
(
disaggregated_params
)
=>
{
Ok
(
prefill_result
)
=>
{
tracing
::
debug!
(
"Prefill succeeded, using disaggregated params for decode"
);
tracing
::
debug!
(
"Prefill succeeded, using disaggregated params for decode"
);
// Update request with disaggregated_params and router config
let
mut
decode_req
=
req
;
let
mut
decode_req
=
req
;
decode_req
.disaggregated_params
=
Some
(
disaggregated_params
);
// Update request with prefill result
decode_req
.prefill_result
=
Some
(
prefill_result
.clone
());
// Restore original max_tokens for decode
// Restore original max_tokens for decode
decode_req
.stop_conditions.max_tokens
=
original_max_tokens
;
decode_req
.stop_conditions.max_tokens
=
original_max_tokens
;
...
...
lib/llm/src/migration.rs
View file @
be9d6b2b
...
@@ -219,6 +219,7 @@ mod tests {
...
@@ -219,6 +219,7 @@ mod tests {
index
:
None
,
index
:
None
,
disaggregated_params
:
None
,
disaggregated_params
:
None
,
extra_args
:
None
,
extra_args
:
None
,
completion_usage
:
None
,
})
})
}
}
...
...
lib/llm/src/mocker/engine.rs
View file @
be9d6b2b
...
@@ -308,6 +308,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
...
@@ -308,6 +308,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
None
None
},
},
extra_args
:
None
,
extra_args
:
None
,
completion_usage
:
None
,
};
};
if
signal
.completed
&&
token_count
<
max_output_tokens
{
if
signal
.completed
&&
token_count
<
max_output_tokens
{
...
...
lib/llm/src/protocols/common/llm_backend.rs
View file @
be9d6b2b
...
@@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize};
...
@@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize};
pub
use
super
::
FinishReason
;
pub
use
super
::
FinishReason
;
pub
use
super
::
preprocessor
::
PreprocessedRequest
;
pub
use
super
::
preprocessor
::
PreprocessedRequest
;
use
crate
::
protocols
::
TokenIdType
;
use
crate
::
protocols
::
TokenIdType
;
use
dynamo_async_openai
::
types
::
CompletionUsage
;
use
dynamo_runtime
::
protocols
::
maybe_error
::
MaybeError
;
use
dynamo_runtime
::
protocols
::
maybe_error
::
MaybeError
;
pub
type
TokenType
=
Option
<
String
>
;
pub
type
TokenType
=
Option
<
String
>
;
...
@@ -48,6 +49,10 @@ pub struct BackendOutput {
...
@@ -48,6 +49,10 @@ pub struct BackendOutput {
// Index field for batch requests to match OpenAI format
// Index field for batch requests to match OpenAI format
pub
index
:
Option
<
u32
>
,
pub
index
:
Option
<
u32
>
,
// Token usage information
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
completion_usage
:
Option
<
CompletionUsage
>
,
}
}
/// The LLM engine and backend will manage its own state, specifically translating how a
/// The LLM engine and backend will manage its own state, specifically translating how a
...
@@ -92,6 +97,10 @@ pub struct LLMEngineOutput {
...
@@ -92,6 +97,10 @@ pub struct LLMEngineOutput {
/// Additional arguments for extensibility
/// Additional arguments for extensibility
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
extra_args
:
Option
<
serde_json
::
Value
>
,
pub
extra_args
:
Option
<
serde_json
::
Value
>
,
// Token usage information
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
completion_usage
:
Option
<
CompletionUsage
>
,
}
}
impl
LLMEngineOutput
{
impl
LLMEngineOutput
{
...
@@ -107,6 +116,7 @@ impl LLMEngineOutput {
...
@@ -107,6 +116,7 @@ impl LLMEngineOutput {
index
:
None
,
index
:
None
,
disaggregated_params
:
None
,
disaggregated_params
:
None
,
extra_args
:
None
,
extra_args
:
None
,
completion_usage
:
None
,
}
}
}
}
...
@@ -122,6 +132,7 @@ impl LLMEngineOutput {
...
@@ -122,6 +132,7 @@ impl LLMEngineOutput {
index
:
None
,
index
:
None
,
disaggregated_params
:
None
,
disaggregated_params
:
None
,
extra_args
:
None
,
extra_args
:
None
,
completion_usage
:
None
,
}
}
}
}
...
@@ -137,6 +148,7 @@ impl LLMEngineOutput {
...
@@ -137,6 +148,7 @@ impl LLMEngineOutput {
index
:
None
,
index
:
None
,
disaggregated_params
:
None
,
disaggregated_params
:
None
,
extra_args
:
None
,
extra_args
:
None
,
completion_usage
:
None
,
}
}
}
}
...
@@ -152,6 +164,7 @@ impl LLMEngineOutput {
...
@@ -152,6 +164,7 @@ impl LLMEngineOutput {
index
:
None
,
index
:
None
,
disaggregated_params
:
None
,
disaggregated_params
:
None
,
extra_args
:
None
,
extra_args
:
None
,
completion_usage
:
None
,
}
}
}
}
}
}
...
...
lib/llm/src/protocols/common/preprocessor.rs
View file @
be9d6b2b
...
@@ -8,6 +8,15 @@ use super::{OutputOptions, SamplingOptions, StopConditions};
...
@@ -8,6 +8,15 @@ use super::{OutputOptions, SamplingOptions, StopConditions};
use
crate
::
kv_router
::
RouterConfigOverride
;
use
crate
::
kv_router
::
RouterConfigOverride
;
use
crate
::
protocols
::
TokenIdType
;
use
crate
::
protocols
::
TokenIdType
;
/// Structured outcome of a prefill call, carried on the request as `prefill_result`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PrefillResult {
    /// Disaggregated execution parameters, kept as an untyped JSON value.
    pub disaggregated_params: serde_json::Value,

    /// Prompt token details produced during prefill.
    ///
    /// Left out of the serialized form when `None`, and defaults to `None`
    /// when the field is absent on deserialization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<dynamo_async_openai::types::PromptTokensDetails>,
}
#[derive(Serialize,
Deserialize,
Debug,
Clone)]
#[derive(Serialize,
Deserialize,
Debug,
Clone)]
pub
enum
MultimodalData
{
pub
enum
MultimodalData
{
Url
(
url
::
Url
),
Url
(
url
::
Url
),
...
@@ -69,10 +78,10 @@ pub struct PreprocessedRequest {
...
@@ -69,10 +78,10 @@ pub struct PreprocessedRequest {
#[builder(default)]
#[builder(default)]
pub
router_config_override
:
Option
<
RouterConfigOverride
>
,
pub
router_config_override
:
Option
<
RouterConfigOverride
>
,
///
Disaggregated execution parameters (for prefill/decode separation)
///
Structured prefill result
#[builder(default)]
#[builder(default)]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
disaggregated_params
:
Option
<
serde_json
::
Value
>
,
pub
prefill_result
:
Option
<
PrefillResult
>
,
/// Data parallel rank for the request (used with data parallelism)
/// Data parallel rank for the request (used with data parallelism)
#[builder(default)]
#[builder(default)]
...
...
lib/llm/src/protocols/openai/chat_completions/delta.rs
View file @
be9d6b2b
...
@@ -316,6 +316,16 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
...
@@ -316,6 +316,16 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
.expect
(
"token_ids length exceeds u32::MAX"
);
.expect
(
"token_ids length exceeds u32::MAX"
);
self
.usage.completion_tokens
+=
token_length
;
self
.usage.completion_tokens
+=
token_length
;
// If backend provides completion_usage with prompt token details,
// propagate the entire details struct to usage tracking
if
let
Some
(
prompt_details
)
=
delta
.completion_usage
.as_ref
()
.and_then
(|
usage
|
usage
.prompt_tokens_details
.as_ref
())
{
self
.usage.prompt_tokens_details
=
Some
(
prompt_details
.clone
());
}
}
}
let
logprobs
=
self
.create_logprobs
(
let
logprobs
=
self
.create_logprobs
(
...
...
lib/llm/src/protocols/openai/completions/delta.rs
View file @
be9d6b2b
...
@@ -238,6 +238,16 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
...
@@ -238,6 +238,16 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
.expect
(
"token_ids length exceeds u32::MAX"
);
.expect
(
"token_ids length exceeds u32::MAX"
);
self
.usage.completion_tokens
+=
token_length
;
self
.usage.completion_tokens
+=
token_length
;
// If backend provides completion_usage with prompt token details,
// propagate the entire details struct to usage tracking
if
let
Some
(
prompt_details
)
=
delta
.completion_usage
.as_ref
()
.and_then
(|
usage
|
usage
.prompt_tokens_details
.as_ref
())
{
self
.usage.prompt_tokens_details
=
Some
(
prompt_details
.clone
());
}
}
}
let
logprobs
=
self
.create_logprobs
(
let
logprobs
=
self
.create_logprobs
(
...
...
lib/llm/tests/test_streaming_usage.rs
View file @
be9d6b2b
...
@@ -7,12 +7,17 @@ use dynamo_async_openai::types::{
...
@@ -7,12 +7,17 @@ use dynamo_async_openai::types::{
ChatCompletionRequestUserMessageContent
,
ChatCompletionStreamOptions
,
ChatCompletionRequestUserMessageContent
,
ChatCompletionStreamOptions
,
CreateChatCompletionRequest
,
CreateChatCompletionRequest
,
};
};
use
dynamo_async_openai
::
types
::{
CompletionUsage
as
AoaiCompletionUsage
,
CreateCompletionRequestArgs
,
Prompt
,
PromptTokensDetails
,
};
use
dynamo_llm
::
preprocessor
::
OpenAIPreprocessor
;
use
dynamo_llm
::
preprocessor
::
OpenAIPreprocessor
;
use
dynamo_llm
::
protocols
::
common
::
llm_backend
::{
BackendOutput
,
FinishReason
};
use
dynamo_llm
::
protocols
::
common
::
llm_backend
::{
BackendOutput
,
FinishReason
};
use
dynamo_llm
::
protocols
::
openai
::
ParsingOptions
;
use
dynamo_llm
::
protocols
::
openai
::
ParsingOptions
;
use
dynamo_llm
::
protocols
::
openai
::
chat_completions
::{
use
dynamo_llm
::
protocols
::
openai
::
chat_completions
::{
NvCreateChatCompletionRequest
,
aggregator
::
ChatCompletionAggregator
,
NvCreateChatCompletionRequest
,
aggregator
::
ChatCompletionAggregator
,
};
};
use
dynamo_llm
::
protocols
::
openai
::
completions
::
NvCreateCompletionRequest
;
use
dynamo_runtime
::
engine
::{
AsyncEngineContext
,
AsyncEngineStream
};
use
dynamo_runtime
::
engine
::{
AsyncEngineContext
,
AsyncEngineStream
};
use
dynamo_runtime
::
protocols
::
annotated
::
Annotated
;
use
dynamo_runtime
::
protocols
::
annotated
::
Annotated
;
use
futures
::
StreamExt
;
use
futures
::
StreamExt
;
...
@@ -82,8 +87,17 @@ impl AsyncEngineContext for MockContext {
...
@@ -82,8 +87,17 @@ impl AsyncEngineContext for MockContext {
fn
create_mock_backend_stream
(
fn
create_mock_backend_stream
(
ctx
:
Arc
<
dyn
AsyncEngineContext
>
,
ctx
:
Arc
<
dyn
AsyncEngineContext
>
,
)
->
Pin
<
Box
<
dyn
AsyncEngineStream
<
Annotated
<
BackendOutput
>>>>
{
)
->
Pin
<
Box
<
dyn
AsyncEngineStream
<
Annotated
<
BackendOutput
>>>>
{
let
outputs
=
vec!
[
let
outputs
=
build_backend_outputs_with_cached_tokens
(
None
);
// First chunk with "Hello"
let
stream
=
stream
::
iter
(
outputs
.into_iter
()
.map
(
Annotated
::
from_data
));
use
dynamo_runtime
::
engine
::
ResponseStream
;
ResponseStream
::
new
(
Box
::
pin
(
stream
),
ctx
)
}
/// Build three backend outputs: "Hello", " world", "!" with optional cached_tokens on the final chunk
fn
build_backend_outputs_with_cached_tokens
(
cached_tokens
:
Option
<
u32
>
)
->
Vec
<
BackendOutput
>
{
vec!
[
BackendOutput
{
BackendOutput
{
token_ids
:
vec!
[
15339
],
token_ids
:
vec!
[
15339
],
tokens
:
vec!
[
Some
(
"Hello"
.to_string
())],
tokens
:
vec!
[
Some
(
"Hello"
.to_string
())],
...
@@ -93,8 +107,8 @@ fn create_mock_backend_stream(
...
@@ -93,8 +107,8 @@ fn create_mock_backend_stream(
top_logprobs
:
None
,
top_logprobs
:
None
,
finish_reason
:
None
,
finish_reason
:
None
,
index
:
Some
(
0
),
index
:
Some
(
0
),
completion_usage
:
None
,
},
},
// Second chunk with " world"
BackendOutput
{
BackendOutput
{
token_ids
:
vec!
[
1917
],
token_ids
:
vec!
[
1917
],
tokens
:
vec!
[
Some
(
" world"
.to_string
())],
tokens
:
vec!
[
Some
(
" world"
.to_string
())],
...
@@ -104,8 +118,8 @@ fn create_mock_backend_stream(
...
@@ -104,8 +118,8 @@ fn create_mock_backend_stream(
top_logprobs
:
None
,
top_logprobs
:
None
,
finish_reason
:
None
,
finish_reason
:
None
,
index
:
Some
(
0
),
index
:
Some
(
0
),
completion_usage
:
None
,
},
},
// Third chunk with "!" and finish_reason
BackendOutput
{
BackendOutput
{
token_ids
:
vec!
[
0
],
token_ids
:
vec!
[
0
],
tokens
:
vec!
[
Some
(
"!"
.to_string
())],
tokens
:
vec!
[
Some
(
"!"
.to_string
())],
...
@@ -115,11 +129,27 @@ fn create_mock_backend_stream(
...
@@ -115,11 +129,27 @@ fn create_mock_backend_stream(
top_logprobs
:
None
,
top_logprobs
:
None
,
finish_reason
:
Some
(
FinishReason
::
Stop
),
finish_reason
:
Some
(
FinishReason
::
Stop
),
index
:
Some
(
0
),
index
:
Some
(
0
),
completion_usage
:
cached_tokens
.map
(|
ct
|
AoaiCompletionUsage
{
prompt_tokens
:
0
,
completion_tokens
:
0
,
total_tokens
:
0
,
prompt_tokens_details
:
Some
(
PromptTokensDetails
{
audio_tokens
:
None
,
cached_tokens
:
Some
(
ct
),
}),
completion_tokens_details
:
None
,
}),
},
},
];
]
}
/// Create a backend stream from standard outputs with optional cached_tokens in the final chunk
fn
create_backend_stream_with_cached_tokens
(
ctx
:
Arc
<
dyn
AsyncEngineContext
>
,
cached_tokens
:
Option
<
u32
>
,
)
->
Pin
<
Box
<
dyn
AsyncEngineStream
<
Annotated
<
BackendOutput
>>>>
{
let
outputs
=
build_backend_outputs_with_cached_tokens
(
cached_tokens
);
let
stream
=
stream
::
iter
(
outputs
.into_iter
()
.map
(
Annotated
::
from_data
));
let
stream
=
stream
::
iter
(
outputs
.into_iter
()
.map
(
Annotated
::
from_data
));
use
dynamo_runtime
::
engine
::
ResponseStream
;
use
dynamo_runtime
::
engine
::
ResponseStream
;
ResponseStream
::
new
(
Box
::
pin
(
stream
),
ctx
)
ResponseStream
::
new
(
Box
::
pin
(
stream
),
ctx
)
}
}
...
@@ -308,6 +338,31 @@ async fn test_streaming_with_usage_false() {
...
@@ -308,6 +338,31 @@ async fn test_streaming_with_usage_false() {
}
}
}
}
/// Builds a completion request for model `test-model` with the prompt `"Hello"`.
///
/// `stream` is forwarded to the request builder unchanged. Stream options are
/// attached only when `include_usage` is `Some`, carrying that flag; with
/// `None` the builder's stream options are left untouched.
///
/// Panics if the underlying builder fails to build the request.
fn create_cmpl_request(include_usage: Option<bool>, stream: bool) -> NvCreateCompletionRequest {
    let mut args = CreateCompletionRequestArgs::default();
    args.model("test-model")
        .prompt(Prompt::String("Hello".to_string()))
        .stream(stream);

    if let Some(include_usage) = include_usage {
        args.stream_options(dynamo_async_openai::types::ChatCompletionStreamOptions {
            include_usage,
        });
    }

    NvCreateCompletionRequest {
        inner: args.build().unwrap(),
        common: Default::default(),
        nvext: None,
        metadata: None,
        unsupported_fields: Default::default(),
    }
}
/// Helper to create a non-streaming chat completion request
/// Helper to create a non-streaming chat completion request
fn
create_nonstreaming_chat_request
()
->
NvCreateChatCompletionRequest
{
fn
create_nonstreaming_chat_request
()
->
NvCreateChatCompletionRequest
{
let
messages
=
vec!
[
ChatCompletionRequestMessage
::
User
(
let
messages
=
vec!
[
ChatCompletionRequestMessage
::
User
(
...
@@ -404,3 +459,195 @@ async fn test_nonstreaming_has_usage_field() {
...
@@ -404,3 +459,195 @@ async fn test_nonstreaming_has_usage_field() {
"Total tokens should equal prompt_tokens + completion_tokens"
"Total tokens should equal prompt_tokens + completion_tokens"
);
);
}
}
/// Completions path with `stream=true` and `include_usage=true` while the backend
/// never sends `completion_usage`: the usage-only chunk must still carry the
/// aggregated token count, and `prompt_tokens_details` must stay `None`.
#[tokio::test]
async fn test_cmpl_streaming_with_usage_true_no_backend_usage() {
    let request = create_cmpl_request(Some(true), true);
    let response_generator =
        Box::new(request.response_generator("cmpl-usage-none-1".to_string()));

    // Backend mock whose chunks carry no completion_usage at all.
    let ctx = Arc::new(MockContext::new());
    let backend_stream = create_mock_backend_stream(ctx.clone());

    // Run the backend stream through the postprocessor and gather every chunk.
    let chunks: Vec<_> = OpenAIPreprocessor::transform_postprocessor_stream(
        backend_stream,
        response_generator,
        ctx.clone(),
    )
    .collect()
    .await;

    // Three content chunks followed by a single usage-only chunk.
    assert_eq!(chunks.len(), 4, "Should have 3 content + 1 usage chunk");

    // Content chunks carry choices but no usage; chunks without data are skipped.
    for (idx, chunk) in chunks.iter().enumerate().take(3) {
        let Some(resp) = chunk.data.as_ref() else {
            continue;
        };
        assert!(
            resp.inner.usage.is_none(),
            "Content chunk {} should have usage: None",
            idx
        );
        assert!(
            !resp.inner.choices.is_empty(),
            "Content chunk {} should have choices",
            idx
        );
    }

    // The last chunk holds only usage: counts present, no prompt token details.
    let Some(final_resp) = chunks[3].data.as_ref() else {
        panic!("Final chunk should be present");
    };
    assert!(
        final_resp.inner.choices.is_empty(),
        "Usage-only chunk must have empty choices"
    );

    let usage = final_resp
        .inner
        .usage
        .as_ref()
        .expect("Usage must be present");
    assert_eq!(
        usage.completion_tokens, 3,
        "Aggregated completion tokens should be 3"
    );
    assert!(
        usage.prompt_tokens_details.is_none(),
        "prompt_tokens_details should be None when backend does not send usage"
    );
}
#[tokio::test]
async fn test_cmpl_streaming_with_cached_tokens_propagation() {
    // Completions with include_usage=true: when the backend reports cached_tokens, the value
    // must show up in the final usage chunk.
    let req = create_cmpl_request(Some(true), true);
    let mut generator = Box::new(req.response_generator("cmpl-usage-cached-1".to_string()));

    // Backend stream built with cached_tokens = Some(7)
    let context = Arc::new(MockContext::new());
    let backend = create_backend_stream_with_cached_tokens(context.clone(), Some(7));

    // Original author's note: align ISL so total usage gets computed correctly
    generator.update_isl(0);

    let responses: Vec<_> =
        OpenAIPreprocessor::transform_postprocessor_stream(backend, generator, context.clone())
            .collect()
            .await;

    // Expect 4 chunks total
    assert_eq!(responses.len(), 4, "Should have 3 content + 1 usage chunk");

    // The last chunk carries the usage block; read cached_tokens out of it
    let last = responses[3]
        .data
        .as_ref()
        .expect("Final chunk should be present");
    let usage_block = last
        .inner
        .usage
        .as_ref()
        .expect("Usage must be present on final chunk");
    let cached_count = usage_block
        .prompt_tokens_details
        .as_ref()
        .and_then(|details| details.cached_tokens);

    assert_eq!(
        cached_count,
        Some(7),
        "cached_tokens must propagate to final usage chunk"
    );
}
#[tokio::test]
async fn test_chat_streaming_with_cached_tokens_propagation() {
    // Chat completions with include_usage=true: when the backend reports cached_tokens, the
    // value must show up in the final usage chunk.
    let req = create_chat_request(Some(true));
    let mut generator = Box::new(req.response_generator("chat-usage-cached-1".to_string()));

    // Backend stream built with cached_tokens = Some(5)
    let context = Arc::new(MockContext::new());
    let backend = create_backend_stream_with_cached_tokens(context.clone(), Some(5));

    // Original author's note: align ISL if needed
    generator.update_isl(0);

    let responses: Vec<_> =
        OpenAIPreprocessor::transform_postprocessor_stream(backend, generator, context.clone())
            .collect()
            .await;

    assert_eq!(responses.len(), 4, "Should have 3 content + 1 usage chunk");

    // The last chunk carries the usage block; read cached_tokens out of it
    let last = responses[3]
        .data
        .as_ref()
        .expect("Final chunk should be present");
    let usage_block = last.usage.as_ref().expect("Usage must be present");
    let cached_count = usage_block
        .prompt_tokens_details
        .as_ref()
        .and_then(|details| details.cached_tokens);

    assert_eq!(
        cached_count,
        Some(5),
        "cached_tokens must propagate for chat completions"
    );
}
#[tokio::test]
async fn test_cmpl_nonstreaming_has_usage_and_cached_tokens() {
    // Non-streaming completions must include usage in the final aggregated response and
    // propagate cached_tokens.
    let mut request = create_cmpl_request(None, false);

    // Simulate preprocessor behavior for non-streaming: pass the request's own stream flag
    // (false when unset) to enable_usage_for_nonstreaming.
    let original_stream_flag = request.inner.stream.unwrap_or(false);
    request.enable_usage_for_nonstreaming(original_stream_flag);

    let request_id = "cmpl-nonstream-usage".to_string();
    let response_generator = Box::new(request.response_generator(request_id));

    // Mock backend stream built with cached_tokens = Some(9)
    let ctx = Arc::new(MockContext::new());
    let backend_stream = create_backend_stream_with_cached_tokens(ctx.clone(), Some(9));

    // Transform to OpenAI completion stream
    let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
        backend_stream,
        response_generator,
        ctx.clone(),
    );

    // Aggregate into a single non-streaming response
    let parsing = ParsingOptions::default();
    let result =
        dynamo_llm::protocols::openai::completions::NvCreateCompletionResponse::from_annotated_stream(
            transformed_stream,
            parsing,
        )
        .await;

    // `expect` (rather than `assert!(is_ok())` + `unwrap()`) so a failure prints the
    // underlying aggregation error alongside the message.
    let resp = result.expect("Aggregation should succeed");

    let usage = resp
        .inner
        .usage
        .expect("usage must be present for non-streaming");
    assert_eq!(
        usage.completion_tokens, 3,
        "completion_tokens must aggregate"
    );

    let cached = usage.prompt_tokens_details.and_then(|d| d.cached_tokens);
    assert_eq!(
        cached,
        Some(9),
        "cached_tokens must propagate to non-streaming response"
    );
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment