Unverified Commit e930526b authored by Yuewei Na's avatar Yuewei Na Committed by GitHub
Browse files

fix: serialize disagg first_gen_log_probs int keys for Rust transport (#7145)


Signed-off-by: default avatarYuewei Na <nv-yna@users.noreply.github.com>
Co-authored-by: default avatarYuewei Na <nv-yna@users.noreply.github.com>
parent 761e848b
...@@ -294,6 +294,10 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -294,6 +294,10 @@ class HandlerBase(BaseGenerativeHandler):
# Remove worker_id if present (added by prefill worker, not needed for decode) # Remove worker_id if present (added by prefill worker, not needed for decode)
params_dict.pop("worker_id", None) params_dict.pop("worker_id", None)
# Deserialize first_gen_log_probs from transport format back to
# TRT-LLM's internal {token_id: Logprob} dict format.
DisaggregatedParamsCodec.deserialize_first_gen_log_probs(params_dict)
# Extract EPD metadata that was packed by prefill worker # Extract EPD metadata that was packed by prefill worker
epd_metadata = {} epd_metadata = {}
if "_epd_metadata" in params_dict: if "_epd_metadata" in params_dict:
...@@ -385,6 +389,9 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -385,6 +389,9 @@ class HandlerBase(BaseGenerativeHandler):
logging.debug("PREFILL: Successfully encoded disaggregated params") logging.debug("PREFILL: Successfully encoded disaggregated params")
params_dict = asdict(encoded_params) params_dict = asdict(encoded_params)
# Serialize first_gen_log_probs for the Rust transport layer.
DisaggregatedParamsCodec.serialize_first_gen_log_probs(params_dict)
# Pack prefill metadata for DECODE worker optimization # Pack prefill metadata for DECODE worker optimization
# The frontend only forwards disaggregated_params from prefill response # The frontend only forwards disaggregated_params from prefill response
# Note: max_tokens is already handled by Rust frontend's PrefillRouter # Note: max_tokens is already handled by Rust frontend's PrefillRouter
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import base64 import base64
import dataclasses import dataclasses
from tensorrt_llm.executor.result import Logprob
from tensorrt_llm.llmapi import DisaggregatedParams from tensorrt_llm.llmapi import DisaggregatedParams
...@@ -24,6 +25,58 @@ class DisaggregatedParamsCodec: ...@@ -24,6 +25,58 @@ class DisaggregatedParamsCodec:
Codec for encoding and decoding disaggregated params for network transfer. Codec for encoding and decoding disaggregated params for network transfer.
""" """
@staticmethod
def serialize_first_gen_log_probs(params_dict: dict) -> None:
    """Rewrite ``first_gen_log_probs`` in place into a JSON-safe layout.

    TRT-LLM stores logprobs as ``[{token_id(int): Logprob, ...}, ...]``
    where dict keys are integer token IDs. The Rust transport layer
    (pythonize 0.23 → serde_json::Value) requires string map keys, so
    each per-position dict is flattened to a list of records, matching
    TRT-LLM's own ``_serialize_first_gen_log_probs`` in
    ``openai_protocol.py``::

        Input:  [{4710: Logprob(-2.32, rank=1), 6771: Logprob(-2.51, rank=2)}]
        Output: [[{"token_id": 4710, "logprob": -2.32, "rank": 1},
                  {"token_id": 6771, "logprob": -2.51, "rank": 2}]]

    Each per-token value may be either a mapping with ``"logprob"`` /
    ``"rank"`` keys (the shape ``dataclasses.asdict()`` produces) or an
    object exposing ``.logprob`` / ``.rank`` attributes.

    Args:
        params_dict: Disaggregated params as a plain dict. Mutated in place.
            A missing, ``None`` or empty ``first_gen_log_probs`` is left
            untouched.

    Returns:
        None. The result is written back to
        ``params_dict["first_gen_log_probs"]``.
    """
    positions = params_dict.get("first_gen_log_probs")
    if not positions:
        return

    def _to_record(token_id, entry):
        # Mapping entries come from dataclasses.asdict(); anything else is
        # treated as an attribute-style object so that un-flattened values
        # do not fail on subscripting.
        if isinstance(entry, dict):
            logprob, rank = entry["logprob"], entry.get("rank")
        else:
            logprob, rank = entry.logprob, getattr(entry, "rank", None)
        return {"token_id": token_id, "logprob": logprob, "rank": rank}

    params_dict["first_gen_log_probs"] = [
        # Positions that are not dicts (e.g. already serialized lists) pass
        # through unchanged, which keeps the call idempotent.
        [_to_record(tid, entry) for tid, entry in pos.items()]
        if isinstance(pos, dict)
        else pos
        for pos in positions
    ]
@staticmethod
def deserialize_first_gen_log_probs(params_dict: dict) -> None:
    """Rebuild ``first_gen_log_probs`` in place from the transport layout.

    This is the inverse of the list-of-records transport format: every
    position that arrives as a list of ``{"token_id", "logprob", "rank"}``
    records becomes a ``{token_id: Logprob}`` dict again. TRT-LLM's
    ``py_executor.py`` calls ``append_log_probs``, which reads the
    ``.logprob`` attribute of the dict values, so real ``Logprob``
    instances have to be reconstructed rather than plain dicts.

    Args:
        params_dict: Disaggregated params as a plain dict. Mutated in place.
            A missing, ``None`` or empty ``first_gen_log_probs`` is left
            untouched.

    Returns:
        None. The result is written back to
        ``params_dict["first_gen_log_probs"]``.
    """
    transported = params_dict.get("first_gen_log_probs")
    if not transported:
        return

    rebuilt = []
    for position in transported:
        # Anything that is not a list of records is kept exactly as received.
        if not isinstance(position, list):
            rebuilt.append(position)
            continue

        by_token = {}
        for record in position:
            token_id = record["token_id"]
            by_token[token_id] = Logprob(
                logprob=record["logprob"], rank=record.get("rank")
            )
        rebuilt.append(by_token)

    params_dict["first_gen_log_probs"] = rebuilt
@staticmethod @staticmethod
def decode( def decode(
disaggregated_params: DisaggregatedParams, disaggregated_params: DisaggregatedParams,
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for disaggregated logprobs serialization round-trip.
TRT-LLM PR #11727 adds first_gen_log_probs to DisaggregatedParams with
integer token-ID dict keys ({4710: Logprob(...)}). The Dynamo Rust
transport layer (pythonize 0.23 → serde_json::Value) requires string map
keys. These tests verify the codec correctly converts between TRT-LLM's
internal format and a JSON-safe transport format.
Mirrors TRT-LLM's TestFirstGenLogProbsSerializeRoundtrip in
tests/unittest/disaggregated/test_openai_disagg_service.py.
"""
import dataclasses
import json
import pytest
try:
from tensorrt_llm.executor.result import Logprob
from dynamo.trtllm.utils.disagg_utils import DisaggregatedParamsCodec
except ImportError as e:
pytest.skip(f"tensorrt_llm import failed: {e}", allow_module_level=True)
def _to_asdict_format(logprob_dicts):
    """Turn ``[{int: Logprob}]`` into ``[{int: dict}]``.

    This mirrors what ``dataclasses.asdict()`` produces in the production
    flow, where each dataclass value has already been flattened to a plain
    dict before the serializer is called.
    """
    flattened = []
    for position in logprob_dicts:
        as_plain = {}
        for token_id, logprob in position.items():
            as_plain[token_id] = dataclasses.asdict(logprob)
        flattened.append(as_plain)
    return flattened
@pytest.mark.pre_merge
@pytest.mark.trtllm
@pytest.mark.gpu_0
@pytest.mark.unit
class TestDisaggLogprobsSerializationRoundtrip:
    """Serialize → deserialize cycle checks for ``first_gen_log_probs``."""

    def test_none_passthrough(self):
        # A None field must survive both directions untouched.
        payload = {"first_gen_log_probs": None}

        DisaggregatedParamsCodec.serialize_first_gen_log_probs(payload)
        assert payload["first_gen_log_probs"] is None

        DisaggregatedParamsCodec.deserialize_first_gen_log_probs(payload)
        assert payload["first_gen_log_probs"] is None

    def test_missing_field_noop(self):
        # Neither direction may invent the key when it is absent.
        payload = {"request_type": "context_only"}

        DisaggregatedParamsCodec.serialize_first_gen_log_probs(payload)
        assert "first_gen_log_probs" not in payload

        DisaggregatedParamsCodec.deserialize_first_gen_log_probs(payload)
        assert "first_gen_log_probs" not in payload

    def test_single_token_roundtrip(self):
        source = [{4710: Logprob(logprob=-2.3256, rank=1)}]
        # In production, dataclasses.asdict() converts Logprob → dict before
        # serialize is called. Mimic that here.
        payload = {"first_gen_log_probs": _to_asdict_format(source)}
        DisaggregatedParamsCodec.serialize_first_gen_log_probs(payload)

        # Serialized layout: list of lists of dicts (no int dict keys).
        wire = payload["first_gen_log_probs"]
        assert isinstance(wire, list)
        assert isinstance(wire[0], list)
        first_record = wire[0][0]
        assert first_record["token_id"] == 4710
        assert first_record["logprob"] == pytest.approx(-2.3256)
        assert first_record["rank"] == 1

        # Verify JSON-safe (this is the actual failure point — int dict keys
        # cause pythonize 0.23 depythonize to fail with dict_key_not_string).
        json.dumps(payload)

        # And back again.
        DisaggregatedParamsCodec.deserialize_first_gen_log_probs(payload)
        restored = payload["first_gen_log_probs"]
        assert len(restored) == 1
        assert 4710 in restored[0]
        entry = restored[0][4710]
        assert isinstance(entry, Logprob)
        assert entry.logprob == pytest.approx(-2.3256)
        assert entry.rank == 1

    def test_multi_token_topk_roundtrip(self):
        first_position = {
            100: Logprob(logprob=-0.1, rank=1),
            200: Logprob(logprob=-2.3, rank=2),
            300: Logprob(logprob=-5.0, rank=3),
        }
        second_position = {
            400: Logprob(logprob=-0.05, rank=1),
            500: Logprob(logprob=-3.7, rank=2),
        }
        source = [first_position, second_position]

        payload = {"first_gen_log_probs": _to_asdict_format(source)}
        DisaggregatedParamsCodec.serialize_first_gen_log_probs(payload)
        json.dumps(payload)  # Must be JSON-safe
        DisaggregatedParamsCodec.deserialize_first_gen_log_probs(payload)

        restored = payload["first_gen_log_probs"]
        assert len(restored) == 2
        assert set(restored[0].keys()) == {100, 200, 300}
        assert set(restored[1].keys()) == {400, 500}
        for expected_pos, actual_pos in zip(source, restored, strict=True):
            for token_id, expected in expected_pos.items():
                actual = actual_pos[token_id]
                assert actual.logprob == pytest.approx(expected.logprob)
                assert actual.rank == expected.rank

    def test_rank_none_preserved(self):
        # rank=None has to come back as None, not be dropped or defaulted.
        flattened = _to_asdict_format([{42: Logprob(logprob=-1.0, rank=None)}])
        payload = {"first_gen_log_probs": flattened}
        DisaggregatedParamsCodec.serialize_first_gen_log_probs(payload)
        DisaggregatedParamsCodec.deserialize_first_gen_log_probs(payload)
        assert payload["first_gen_log_probs"][0][42].rank is None

    def test_empty_list_passthrough(self):
        payload = {"first_gen_log_probs": []}
        DisaggregatedParamsCodec.serialize_first_gen_log_probs(payload)
        assert payload["first_gen_log_probs"] == []
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment