"vscode:/vscode.git/clone" did not exist on "a298ce9c1db2abcc32b69db07683724c9d96f034"
Unverified Commit 122777c8 authored by Yifan Jiang's avatar Yifan Jiang Committed by GitHub
Browse files

fix(trtllm): populate disagg_request_id for PYTHON transceiver (#7604)


Signed-off-by: default avatarYifan Jiang <19356972+yifjiang@users.noreply.github.com>
parent 22da711f
...@@ -27,6 +27,7 @@ import torch ...@@ -27,6 +27,7 @@ import torch
from tensorrt_llm.executor.result import GenerationResult from tensorrt_llm.executor.result import GenerationResult
from tensorrt_llm.executor.utils import RequestError from tensorrt_llm.executor.utils import RequestError
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi.disagg_utils import get_global_disagg_request_id
from tensorrt_llm.llmapi.llm import SamplingParams from tensorrt_llm.llmapi.llm import SamplingParams
from tensorrt_llm.sampling_params import GuidedDecodingParams from tensorrt_llm.sampling_params import GuidedDecodingParams
from tensorrt_llm.scheduling_params import SchedulingParams from tensorrt_llm.scheduling_params import SchedulingParams
...@@ -83,6 +84,7 @@ class RequestHandlerConfig: ...@@ -83,6 +84,7 @@ class RequestHandlerConfig:
disable_request_abort: bool = True disable_request_abort: bool = True
additional_metrics: Optional["AdditionalMetricsCollector"] = None additional_metrics: Optional["AdditionalMetricsCollector"] = None
max_seq_len: Optional[int] = None max_seq_len: Optional[int] = None
disagg_machine_id: int = 0 # 10-bit machine_id for snowflake disagg_request_id
class HandlerBase(BaseGenerativeHandler): class HandlerBase(BaseGenerativeHandler):
...@@ -114,6 +116,7 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -114,6 +116,7 @@ class HandlerBase(BaseGenerativeHandler):
self.disable_request_abort = config.disable_request_abort self.disable_request_abort = config.disable_request_abort
self.additional_metrics = config.additional_metrics self.additional_metrics = config.additional_metrics
self.max_seq_len = config.max_seq_len self.max_seq_len = config.max_seq_len
self.disagg_machine_id = config.disagg_machine_id
def check_error(self, result: dict) -> bool: def check_error(self, result: dict) -> bool:
""" """
...@@ -465,7 +468,18 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -465,7 +468,18 @@ class HandlerBase(BaseGenerativeHandler):
disaggregated_params = ep_disaggregated_params disaggregated_params = ep_disaggregated_params
else: else:
disaggregated_params = LlmDisaggregatedParams( disaggregated_params = LlmDisaggregatedParams(
request_type="context_only" request_type="context_only",
disagg_request_id=get_global_disagg_request_id(
self.disagg_machine_id
),
)
# Ensure disagg_request_id is set even when using
# ep_disaggregated_params, so the PYTHON transceiver can track
# requests across prefill/decode workers.
if disaggregated_params.disagg_request_id is None:
disaggregated_params.disagg_request_id = get_global_disagg_request_id(
self.disagg_machine_id
) )
# AGGREGATED (prefill_and_decode) mode with encoder disaggregation: # AGGREGATED (prefill_and_decode) mode with encoder disaggregation:
......
...@@ -440,3 +440,80 @@ class TestMultimodalGuard: ...@@ -440,3 +440,80 @@ class TestMultimodalGuard:
assert result["prompt"] == "describe image" assert result["prompt"] == "describe image"
assert result["prompt_token_ids"] == [1, 2, 3] assert result["prompt_token_ids"] == [1, 2, 3]
assert result["multi_modal_data"] is None assert result["multi_modal_data"] is None
class TestDisaggRequestId:
"""Tests for disagg_request_id population in _setup_disaggregated_params_for_mode."""
def _make_prefill_handler(self, machine_id: int = 42) -> HandlerBase:
config = MagicMock()
config.shutdown_event = None
config.disagg_machine_id = machine_id
handler = _ConcreteHandler(config)
handler.disaggregation_mode = DisaggregationMode.PREFILL
return handler
def test_disagg_request_id_populated_in_prefill_mode(self):
"""When mode is PREFILL and no ep_disaggregated_params, disagg_request_id is set."""
handler = self._make_prefill_handler()
disagg_params, _, _ = handler._setup_disaggregated_params_for_mode(
request={}, ep_disaggregated_params=None
)
assert disagg_params is not None
assert disagg_params.disagg_request_id is not None
assert isinstance(disagg_params.disagg_request_id, int)
def test_disagg_request_id_unique_across_calls(self):
"""Multiple calls should produce different IDs."""
handler = self._make_prefill_handler()
ids = set()
for _ in range(10):
params, _, _ = handler._setup_disaggregated_params_for_mode(
request={}, ep_disaggregated_params=None
)
ids.add(params.disagg_request_id)
assert len(ids) == 10, f"Expected 10 unique IDs, got {len(ids)}"
def test_disagg_request_id_set_on_ep_params_with_none(self):
"""When ep_disaggregated_params has disagg_request_id=None, it gets populated."""
handler = self._make_prefill_handler()
ep_params = MagicMock()
ep_params.disagg_request_id = None
# Make bool(ep_params) truthy so the if-branch is taken
ep_params.__bool__ = lambda self: True
params, _, _ = handler._setup_disaggregated_params_for_mode(
request={}, ep_disaggregated_params=ep_params
)
assert params.disagg_request_id is not None
assert isinstance(params.disagg_request_id, int)
def test_disagg_request_id_not_overwritten_when_set(self):
"""When ep_disaggregated_params already has a disagg_request_id, keep it."""
handler = self._make_prefill_handler()
existing_id = 12345678
ep_params = MagicMock()
ep_params.disagg_request_id = existing_id
ep_params.__bool__ = lambda self: True
params, _, _ = handler._setup_disaggregated_params_for_mode(
request={}, ep_disaggregated_params=ep_params
)
assert params.disagg_request_id == existing_id
def test_machine_id_from_config(self):
"""disagg_machine_id is taken from the handler config."""
handler = self._make_prefill_handler(machine_id=123)
assert handler.disagg_machine_id == 123
def test_different_machine_ids_produce_different_id_ranges(self):
"""Handlers with different machine_ids produce non-overlapping snowflake IDs."""
handler_a = self._make_prefill_handler(machine_id=1)
handler_b = self._make_prefill_handler(machine_id=2)
params_a, _, _ = handler_a._setup_disaggregated_params_for_mode(
request={}, ep_disaggregated_params=None
)
params_b, _, _ = handler_b._setup_disaggregated_params_for_mode(
request={}, ep_disaggregated_params=None
)
assert params_a.disagg_request_id != params_b.disagg_request_id
...@@ -497,6 +497,7 @@ async def init_llm_worker( ...@@ -497,6 +497,7 @@ async def init_llm_worker(
disable_request_abort=config.disable_request_abort, disable_request_abort=config.disable_request_abort,
additional_metrics=additional_metrics, additional_metrics=additional_metrics,
max_seq_len=config.max_seq_len, max_seq_len=config.max_seq_len,
disagg_machine_id=int(endpoint.connection_id()) % 1021,
) )
# Register the model with runtime config # Register the model with runtime config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment