Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
01bfbea1
Unverified
Commit
01bfbea1
authored
Dec 09, 2025
by
Elijah Soba
Committed by
GitHub
Dec 09, 2025
Browse files
feat: Add logprobs support to TRTLLM backend (#4759)
Signed-off-by:
Elijah Soba
<
esoba@nvidia.com
>
parent
1e5b20b2
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
241 additions
and
17 deletions
+241
-17
components/src/dynamo/trtllm/request_handlers/handler_base.py
...onents/src/dynamo/trtllm/request_handlers/handler_base.py
+99
-0
tests/serve/test_trtllm.py
tests/serve/test_trtllm.py
+32
-0
tests/utils/payload_builder.py
tests/utils/payload_builder.py
+44
-17
tests/utils/payloads.py
tests/utils/payloads.py
+66
-0
No files found.
components/src/dynamo/trtllm/request_handlers/handler_base.py
View file @
01bfbea1
...
...
@@ -106,6 +106,76 @@ class HandlerBase:
result
[
"finish_reason"
]
==
"stop"
or
result
[
"finish_reason"
]
==
"error"
)
@
staticmethod
def
_extract_logprobs
(
output
,
num_output_tokens_so_far
:
int
)
->
tuple
[
list
[
float
]
|
None
,
list
[
list
[
dict
]]
|
None
]:
"""
Extract logprobs from the TRTLLM output for new tokens.
Args:
output: TRTLLM CompletionOutput object
num_output_tokens_so_far: Number of tokens already processed
Returns:
Tuple of (log_probs, top_logprobs) in Dynamo's expected format:
- log_probs: List of log probabilities for each new token
- top_logprobs: List of top logprobs dicts for each new token
"""
if
output
.
logprobs
is
None
:
return
None
,
None
# Get logprobs for new tokens only
new_logprobs
=
output
.
logprobs
[
num_output_tokens_so_far
:]
if
not
new_logprobs
:
return
None
,
None
# From TRTLLM CompletionOutput API, logprobs: (TokenLogprobs | List[float], optional)
# Expect TokenLogprobs output when logprobs is set, check edge case where list[float] is returned instead
if
isinstance
(
new_logprobs
[
0
],
float
):
return
[
float
(
lp
)
for
lp
in
new_logprobs
],
None
log_probs
=
[]
top_logprobs
=
[]
for
token_idx
,
token_logprobs_dict
in
enumerate
(
new_logprobs
):
if
token_logprobs_dict
is
None
:
continue
# Get the actual token_id that was generated at this position
actual_token_id
=
output
.
token_ids
[
num_output_tokens_so_far
+
token_idx
]
# Extract log probability for the selected token
if
actual_token_id
in
token_logprobs_dict
:
selected_logprob
=
token_logprobs_dict
[
actual_token_id
]
log_probs
.
append
(
float
(
selected_logprob
.
logprob
))
else
:
# Fallback: use the first logprob if selected token not found
first_logprob
=
next
(
iter
(
token_logprobs_dict
.
values
()),
None
)
if
first_logprob
:
log_probs
.
append
(
float
(
first_logprob
.
logprob
))
# Build top_logprobs list for this token position
# NOTE: TRTLLM LogProb API doesn't have decoded_token, will default to None
token_top_logprobs
=
[]
for
tok_id
,
logprob_info
in
token_logprobs_dict
.
items
():
token_top_logprobs
.
append
(
{
"rank"
:
logprob_info
.
rank
if
hasattr
(
logprob_info
,
"rank"
)
else
0
,
"token_id"
:
tok_id
,
"token"
:
(
logprob_info
.
decoded_token
if
hasattr
(
logprob_info
,
"decoded_token"
)
else
None
),
"logprob"
:
float
(
logprob_info
.
logprob
),
}
)
top_logprobs
.
append
(
token_top_logprobs
)
return
log_probs
if
log_probs
else
None
,
top_logprobs
if
top_logprobs
else
None
async
def
_handle_cancellation
(
self
,
generation_result
:
GenerationResult
,
context
:
Context
):
...
...
@@ -236,6 +306,26 @@ class HandlerBase:
if
hasattr
(
sampling_params
,
key
):
setattr
(
sampling_params
,
key
,
value
)
# Additional sampling params in output options
output_options
=
request
.
get
(
"output_options"
,
{})
if
output_options
:
logprobs_value
=
output_options
.
get
(
"logprobs"
)
# Handle logprobs
if
logprobs_value
is
not
None
:
if
hasattr
(
sampling_params
,
"logprobs"
):
setattr
(
sampling_params
,
"logprobs"
,
max
(
1
,
int
(
logprobs_value
))
)
# If top_logprobs = 0, still want to see chosen token logprob
# Handle prompt_logprobs
prompt_logprobs_value
=
output_options
.
get
(
"prompt_logprobs"
)
if
prompt_logprobs_value
:
if
hasattr
(
sampling_params
,
"prompt_logprobs"
):
setattr
(
sampling_params
,
"prompt_logprobs"
,
int
(
prompt_logprobs_value
)
)
max_tokens
=
request
[
"stop_conditions"
][
"max_tokens"
]
if
max_tokens
:
sampling_params
.
max_tokens
=
max_tokens
...
...
@@ -302,6 +392,15 @@ class HandlerBase:
out
=
{
"token_ids"
:
output
.
token_ids
[
num_output_tokens_so_far
:]}
# Extract logprobs from the output
log_probs
,
top_logprobs
=
self
.
_extract_logprobs
(
output
,
num_output_tokens_so_far
)
if
log_probs
:
out
[
"log_probs"
]
=
log_probs
if
top_logprobs
:
out
[
"top_logprobs"
]
=
top_logprobs
if
output
.
finish_reason
:
out
[
"finish_reason"
]
=
output
.
finish_reason
if
output
.
stop_reason
:
...
...
tests/serve/test_trtllm.py
View file @
01bfbea1
...
...
@@ -14,7 +14,10 @@ from tests.serve.common import (
)
from
tests.utils.engine_process
import
EngineConfig
from
tests.utils.payload_builder
import
(
TEXT_PROMPT
,
chat_payload
,
chat_payload_default
,
completion_payload
,
completion_payload_default
,
metric_payload_default
,
multimodal_payload_default
,
...
...
@@ -91,6 +94,34 @@ trtllm_configs = {
metric_payload_default
(
port
=
8082
,
min_num_requests
=
6
,
backend
=
"trtllm"
),
],
),
"aggregated_logprobs"
:
TRTLLMConfig
(
name
=
"aggregated_logprobs"
,
directory
=
trtllm_dir
,
script_name
=
"agg.sh"
,
marks
=
[
pytest
.
mark
.
gpu_1
,
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
trtllm
],
model
=
"Qwen/Qwen3-0.6B"
,
models_port
=
8000
,
request_payloads
=
[
chat_payload
(
content
=
TEXT_PROMPT
,
logprobs
=
True
,
top_logprobs
=
5
),
chat_payload
(
content
=
TEXT_PROMPT
,
logprobs
=
False
,
top_logprobs
=
5
),
chat_payload
(
content
=
TEXT_PROMPT
,
logprobs
=
True
,
top_logprobs
=
None
),
chat_payload
(
content
=
TEXT_PROMPT
,
logprobs
=
True
,
top_logprobs
=
0
),
],
),
"disaggregated_logprobs"
:
TRTLLMConfig
(
name
=
"disaggregated_logprobs"
,
directory
=
trtllm_dir
,
script_name
=
"disagg.sh"
,
marks
=
[
pytest
.
mark
.
gpu_2
,
pytest
.
mark
.
post_merge
,
pytest
.
mark
.
trtllm
],
model
=
"Qwen/Qwen3-0.6B"
,
models_port
=
8000
,
request_payloads
=
[
chat_payload
(
content
=
TEXT_PROMPT
,
logprobs
=
True
,
top_logprobs
=
5
),
chat_payload
(
content
=
TEXT_PROMPT
,
logprobs
=
False
,
top_logprobs
=
5
),
chat_payload
(
content
=
TEXT_PROMPT
,
logprobs
=
True
,
top_logprobs
=
None
),
chat_payload
(
content
=
TEXT_PROMPT
,
logprobs
=
True
,
top_logprobs
=
0
),
],
),
"aggregated_router"
:
TRTLLMConfig
(
name
=
"aggregated_router"
,
directory
=
trtllm_dir
,
...
...
@@ -159,6 +190,7 @@ trtllm_configs = {
},
request_payloads
=
[
completion_payload_default
(),
completion_payload
(
prompt
=
TEXT_PROMPT
,
logprobs
=
3
),
],
),
}
...
...
tests/utils/payload_builder.py
View file @
01bfbea1
...
...
@@ -6,7 +6,9 @@ from typing import Any, Dict, List, Optional, Union
from
tests.utils.client
import
send_request
from
tests.utils.payloads
import
(
ChatPayload
,
ChatPayloadWithLogprobs
,
CompletionPayload
,
CompletionPayloadWithLogprobs
,
EmbeddingPayload
,
MetricsPayload
,
)
...
...
@@ -134,6 +136,8 @@ def chat_payload(
max_tokens
:
int
=
300
,
temperature
:
Optional
[
float
]
=
None
,
stream
:
bool
=
False
,
logprobs
:
bool
=
False
,
top_logprobs
:
Optional
[
int
]
=
None
,
extra_body
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
ChatPayload
:
body
:
Dict
[
str
,
Any
]
=
{
...
...
@@ -145,19 +149,31 @@ def chat_payload(
],
"max_tokens"
:
max_tokens
,
"stream"
:
stream
,
"logprobs"
:
logprobs
,
}
if
temperature
is
not
None
:
body
[
"temperature"
]
=
temperature
if
top_logprobs
is
not
None
:
body
[
"top_logprobs"
]
=
top_logprobs
if
extra_body
:
body
.
update
(
extra_body
)
return
ChatPayload
(
body
=
body
,
repeat_count
=
repeat_count
,
expected_log
=
expected_log
or
[],
expected_response
=
expected_response
or
[],
)
if
logprobs
:
return
ChatPayloadWithLogprobs
(
body
=
body
,
repeat_count
=
repeat_count
,
expected_log
=
expected_log
or
[],
expected_response
=
expected_response
or
[],
)
else
:
return
ChatPayload
(
body
=
body
,
repeat_count
=
repeat_count
,
expected_log
=
expected_log
or
[],
expected_response
=
expected_response
or
[],
)
def
completion_payload
(
...
...
@@ -168,18 +184,29 @@ def completion_payload(
max_tokens
:
int
=
150
,
temperature
:
float
=
0.1
,
stream
:
bool
=
False
,
logprobs
:
Optional
[
int
]
=
None
,
)
->
CompletionPayload
:
return
CompletionPayload
(
body
=
{
"prompt"
:
prompt
,
"max_tokens"
:
max_tokens
,
"temperature"
:
temperature
,
"stream"
:
stream
,
},
repeat_count
=
repeat_count
,
expected_log
=
expected_log
or
[],
expected_response
=
expected_response
or
[],
)
body
:
Dict
[
str
,
Any
]
=
{
"prompt"
:
prompt
,
"max_tokens"
:
max_tokens
,
"temperature"
:
temperature
,
"stream"
:
stream
,
}
if
logprobs
is
not
None
:
body
[
"logprobs"
]
=
logprobs
return
CompletionPayloadWithLogprobs
(
body
=
body
,
repeat_count
=
repeat_count
,
expected_log
=
expected_log
or
[],
expected_response
=
expected_response
or
[],
)
else
:
return
CompletionPayload
(
body
=
body
,
repeat_count
=
repeat_count
,
expected_log
=
expected_log
or
[],
expected_response
=
expected_response
or
[],
)
def
embedding_payload_default
(
...
...
tests/utils/payloads.py
View file @
01bfbea1
...
...
@@ -155,6 +155,39 @@ class ChatPayload(BasePayload):
return
ChatPayload
.
extract_content
(
response
)
@
dataclass
class
ChatPayloadWithLogprobs
(
ChatPayload
):
"""Chat payload that validates logprobs in response."""
def
validate
(
self
,
response
:
Any
,
content
:
str
)
->
None
:
"""Validate response contains logprobs fields."""
super
().
validate
(
response
,
content
)
result
=
response
.
json
()
choice
=
result
[
"choices"
][
0
]
# Validate logprobs field exists
assert
"logprobs"
in
choice
,
"Missing 'logprobs' in choice"
logprobs_data
=
choice
[
"logprobs"
]
if
logprobs_data
is
not
None
:
assert
"content"
in
logprobs_data
,
"Missing 'content' in logprobs"
content_logprobs
=
logprobs_data
[
"content"
]
if
content_logprobs
:
# Validate structure of logprobs
for
item
in
content_logprobs
:
assert
"token"
in
item
,
"Missing 'token' in logprobs content"
assert
"logprob"
in
item
,
"Missing 'logprob' in logprobs content"
assert
(
"top_logprobs"
in
item
),
"Missing 'top_logprobs' in logprobs content"
logger
.
info
(
f
"✓ Logprobs validation passed: found
{
len
(
content_logprobs
)
}
tokens with logprobs"
)
@
dataclass
class
ToolCallingChatPayload
(
ChatPayload
):
"""ChatPayload that validates tool calls in the response."""
...
...
@@ -220,6 +253,39 @@ class CompletionPayload(BasePayload):
return
CompletionPayload
.
extract_text
(
response
)
@
dataclass
class
CompletionPayloadWithLogprobs
(
CompletionPayload
):
"""Completion payload that validates logprobs in response."""
def
validate
(
self
,
response
:
Any
,
content
:
str
)
->
None
:
"""Validate response contains logprobs fields."""
super
().
validate
(
response
,
content
)
result
=
response
.
json
()
choice
=
result
[
"choices"
][
0
]
# Validate logprobs field exists
assert
"logprobs"
in
choice
,
"Missing 'logprobs' in choice"
logprobs_data
=
choice
[
"logprobs"
]
if
logprobs_data
is
not
None
:
assert
(
"token_logprobs"
in
logprobs_data
),
"Missing 'token_logprobs' in logprobs"
assert
"tokens"
in
logprobs_data
,
"Missing 'tokens' in logprobs"
token_logprobs
=
logprobs_data
[
"token_logprobs"
]
tokens
=
logprobs_data
[
"tokens"
]
if
token_logprobs
:
assert
len
(
token_logprobs
)
==
len
(
tokens
),
"Mismatch between token_logprobs and tokens length"
logger
.
info
(
f
"✓ Logprobs validation passed: found
{
len
(
token_logprobs
)
}
tokens with logprobs"
)
@
dataclass
class
EmbeddingPayload
(
BasePayload
):
"""Payload for embeddings endpoint."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment