Unverified commit 01bfbea1, authored by Elijah Soba and committed by GitHub
Browse files

feat: Add logprobs support to TRTLLM backend (#4759)


Signed-off-by: Elijah Soba <esoba@nvidia.com>
parent 1e5b20b2
......@@ -106,6 +106,76 @@ class HandlerBase:
result["finish_reason"] == "stop" or result["finish_reason"] == "error"
)
@staticmethod
def _extract_logprobs(
output, num_output_tokens_so_far: int
) -> tuple[list[float] | None, list[list[dict]] | None]:
"""
Extract logprobs from the TRTLLM output for new tokens.
Args:
output: TRTLLM CompletionOutput object
num_output_tokens_so_far: Number of tokens already processed
Returns:
Tuple of (log_probs, top_logprobs) in Dynamo's expected format:
- log_probs: List of log probabilities for each new token
- top_logprobs: List of top logprobs dicts for each new token
"""
if output.logprobs is None:
return None, None
# Get logprobs for new tokens only
new_logprobs = output.logprobs[num_output_tokens_so_far:]
if not new_logprobs:
return None, None
# From TRTLLM CompletionOutput API, logprobs: (TokenLogprobs | List[float], optional)
# Expect TokenLogprobs output when logprobs is set, check edge case where list[float] is returned instead
if isinstance(new_logprobs[0], float):
return [float(lp) for lp in new_logprobs], None
log_probs = []
top_logprobs = []
for token_idx, token_logprobs_dict in enumerate(new_logprobs):
if token_logprobs_dict is None:
continue
# Get the actual token_id that was generated at this position
actual_token_id = output.token_ids[num_output_tokens_so_far + token_idx]
# Extract log probability for the selected token
if actual_token_id in token_logprobs_dict:
selected_logprob = token_logprobs_dict[actual_token_id]
log_probs.append(float(selected_logprob.logprob))
else:
# Fallback: use the first logprob if selected token not found
first_logprob = next(iter(token_logprobs_dict.values()), None)
if first_logprob:
log_probs.append(float(first_logprob.logprob))
# Build top_logprobs list for this token position
# NOTE: TRTLLM LogProb API doesn't have decoded_token, will default to None
token_top_logprobs = []
for tok_id, logprob_info in token_logprobs_dict.items():
token_top_logprobs.append(
{
"rank": logprob_info.rank
if hasattr(logprob_info, "rank")
else 0,
"token_id": tok_id,
"token": (
logprob_info.decoded_token
if hasattr(logprob_info, "decoded_token")
else None
),
"logprob": float(logprob_info.logprob),
}
)
top_logprobs.append(token_top_logprobs)
return log_probs if log_probs else None, top_logprobs if top_logprobs else None
async def _handle_cancellation(
self, generation_result: GenerationResult, context: Context
):
......@@ -236,6 +306,26 @@ class HandlerBase:
if hasattr(sampling_params, key):
setattr(sampling_params, key, value)
# Additional sampling params in output options
output_options = request.get("output_options", {})
if output_options:
logprobs_value = output_options.get("logprobs")
# Handle logprobs
if logprobs_value is not None:
if hasattr(sampling_params, "logprobs"):
setattr(
sampling_params, "logprobs", max(1, int(logprobs_value))
) # If top_logprobs = 0, still want to see chosen token logprob
# Handle prompt_logprobs
prompt_logprobs_value = output_options.get("prompt_logprobs")
if prompt_logprobs_value:
if hasattr(sampling_params, "prompt_logprobs"):
setattr(
sampling_params, "prompt_logprobs", int(prompt_logprobs_value)
)
max_tokens = request["stop_conditions"]["max_tokens"]
if max_tokens:
sampling_params.max_tokens = max_tokens
......@@ -302,6 +392,15 @@ class HandlerBase:
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
# Extract logprobs from the output
log_probs, top_logprobs = self._extract_logprobs(
output, num_output_tokens_so_far
)
if log_probs:
out["log_probs"] = log_probs
if top_logprobs:
out["top_logprobs"] = top_logprobs
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
......
......@@ -14,7 +14,10 @@ from tests.serve.common import (
)
from tests.utils.engine_process import EngineConfig
from tests.utils.payload_builder import (
TEXT_PROMPT,
chat_payload,
chat_payload_default,
completion_payload,
completion_payload_default,
metric_payload_default,
multimodal_payload_default,
......@@ -91,6 +94,34 @@ trtllm_configs = {
metric_payload_default(port=8082, min_num_requests=6, backend="trtllm"),
],
),
"aggregated_logprobs": TRTLLMConfig(
name="aggregated_logprobs",
directory=trtllm_dir,
script_name="agg.sh",
marks=[pytest.mark.gpu_1, pytest.mark.pre_merge, pytest.mark.trtllm],
model="Qwen/Qwen3-0.6B",
models_port=8000,
request_payloads=[
chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=5),
chat_payload(content=TEXT_PROMPT, logprobs=False, top_logprobs=5),
chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=None),
chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=0),
],
),
"disaggregated_logprobs": TRTLLMConfig(
name="disaggregated_logprobs",
directory=trtllm_dir,
script_name="disagg.sh",
marks=[pytest.mark.gpu_2, pytest.mark.post_merge, pytest.mark.trtllm],
model="Qwen/Qwen3-0.6B",
models_port=8000,
request_payloads=[
chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=5),
chat_payload(content=TEXT_PROMPT, logprobs=False, top_logprobs=5),
chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=None),
chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=0),
],
),
"aggregated_router": TRTLLMConfig(
name="aggregated_router",
directory=trtllm_dir,
......@@ -159,6 +190,7 @@ trtllm_configs = {
},
request_payloads=[
completion_payload_default(),
completion_payload(prompt=TEXT_PROMPT, logprobs=3),
],
),
}
......
......@@ -6,7 +6,9 @@ from typing import Any, Dict, List, Optional, Union
from tests.utils.client import send_request
from tests.utils.payloads import (
ChatPayload,
ChatPayloadWithLogprobs,
CompletionPayload,
CompletionPayloadWithLogprobs,
EmbeddingPayload,
MetricsPayload,
)
......@@ -134,6 +136,8 @@ def chat_payload(
max_tokens: int = 300,
temperature: Optional[float] = None,
stream: bool = False,
logprobs: bool = False,
top_logprobs: Optional[int] = None,
extra_body: Optional[Dict[str, Any]] = None,
) -> ChatPayload:
body: Dict[str, Any] = {
......@@ -145,19 +149,31 @@ def chat_payload(
],
"max_tokens": max_tokens,
"stream": stream,
"logprobs": logprobs,
}
if temperature is not None:
body["temperature"] = temperature
if top_logprobs is not None:
body["top_logprobs"] = top_logprobs
if extra_body:
body.update(extra_body)
return ChatPayload(
body=body,
repeat_count=repeat_count,
expected_log=expected_log or [],
expected_response=expected_response or [],
)
if logprobs:
return ChatPayloadWithLogprobs(
body=body,
repeat_count=repeat_count,
expected_log=expected_log or [],
expected_response=expected_response or [],
)
else:
return ChatPayload(
body=body,
repeat_count=repeat_count,
expected_log=expected_log or [],
expected_response=expected_response or [],
)
def completion_payload(
......@@ -168,18 +184,29 @@ def completion_payload(
max_tokens: int = 150,
temperature: float = 0.1,
stream: bool = False,
logprobs: Optional[int] = None,
) -> CompletionPayload:
return CompletionPayload(
body={
"prompt": prompt,
"max_tokens": max_tokens,
"temperature": temperature,
"stream": stream,
},
repeat_count=repeat_count,
expected_log=expected_log or [],
expected_response=expected_response or [],
)
body: Dict[str, Any] = {
"prompt": prompt,
"max_tokens": max_tokens,
"temperature": temperature,
"stream": stream,
}
if logprobs is not None:
body["logprobs"] = logprobs
return CompletionPayloadWithLogprobs(
body=body,
repeat_count=repeat_count,
expected_log=expected_log or [],
expected_response=expected_response or [],
)
else:
return CompletionPayload(
body=body,
repeat_count=repeat_count,
expected_log=expected_log or [],
expected_response=expected_response or [],
)
def embedding_payload_default(
......
......@@ -155,6 +155,39 @@ class ChatPayload(BasePayload):
return ChatPayload.extract_content(response)
@dataclass
class ChatPayloadWithLogprobs(ChatPayload):
    """ChatPayload variant that additionally checks the response's logprobs."""

    def validate(self, response: Any, content: str) -> None:
        """Run the parent validation, then check the first choice's logprobs.

        The first choice must carry a 'logprobs' key. A null value, or an
        empty 'content' list, is accepted without further checks. Otherwise
        every entry of 'content' must expose 'token', 'logprob' and
        'top_logprobs'.
        """
        super().validate(response, content)

        first_choice = response.json()["choices"][0]

        # The key itself must be present, even if its value is null.
        assert "logprobs" in first_choice, "Missing 'logprobs' in choice"

        logprobs_section = first_choice["logprobs"]
        if logprobs_section is None:
            return

        assert "content" in logprobs_section, "Missing 'content' in logprobs"
        content_logprobs = logprobs_section["content"]
        if not content_logprobs:
            return

        # Each per-token entry needs the three expected fields.
        for entry in content_logprobs:
            assert "token" in entry, "Missing 'token' in logprobs content"
            assert "logprob" in entry, "Missing 'logprob' in logprobs content"
            assert (
                "top_logprobs" in entry
            ), "Missing 'top_logprobs' in logprobs content"

        logger.info(
            f"✓ Logprobs validation passed: found {len(content_logprobs)} tokens with logprobs"
        )
@dataclass
class ToolCallingChatPayload(ChatPayload):
"""ChatPayload that validates tool calls in the response."""
......@@ -220,6 +253,39 @@ class CompletionPayload(BasePayload):
return CompletionPayload.extract_text(response)
@dataclass
class CompletionPayloadWithLogprobs(CompletionPayload):
    """CompletionPayload variant that additionally checks the response's logprobs."""

    def validate(self, response: Any, content: str) -> None:
        """Run the parent validation, then check the first choice's logprobs.

        The first choice must carry a 'logprobs' key. A null value is
        accepted as-is. Otherwise 'token_logprobs' and 'tokens' must both be
        present and, when 'token_logprobs' is non-empty, the same length.
        """
        super().validate(response, content)

        first_choice = response.json()["choices"][0]

        # The key itself must be present, even if its value is null.
        assert "logprobs" in first_choice, "Missing 'logprobs' in choice"

        logprobs_section = first_choice["logprobs"]
        if logprobs_section is None:
            return

        assert (
            "token_logprobs" in logprobs_section
        ), "Missing 'token_logprobs' in logprobs"
        assert "tokens" in logprobs_section, "Missing 'tokens' in logprobs"

        token_logprobs = logprobs_section["token_logprobs"]
        token_texts = logprobs_section["tokens"]
        if not token_logprobs:
            return

        # The two parallel arrays must describe the same number of tokens.
        assert len(token_logprobs) == len(
            token_texts
        ), "Mismatch between token_logprobs and tokens length"
        logger.info(
            f"✓ Logprobs validation passed: found {len(token_logprobs)} tokens with logprobs"
        )
@dataclass
class EmbeddingPayload(BasePayload):
"""Payload for embeddings endpoint."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment