Commit 5602dd2f (unverified), authored by Yuewei Na, committed by GitHub
Browse files

fix: convert guided_decoding dict to GuidedDecodingParams in TRT-LLM handler (#6127)


Signed-off-by: Yuewei Na <nv-yna@users.noreply.github.com>
Co-authored-by: Yuewei Na <nv-yna@users.noreply.github.com>
parent f9d20c10
...@@ -26,6 +26,7 @@ from tensorrt_llm.executor.result import GenerationResult ...@@ -26,6 +26,7 @@ from tensorrt_llm.executor.result import GenerationResult
from tensorrt_llm.executor.utils import RequestError from tensorrt_llm.executor.utils import RequestError
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi.llm import SamplingParams from tensorrt_llm.llmapi.llm import SamplingParams
from tensorrt_llm.sampling_params import GuidedDecodingParams
from tensorrt_llm.scheduling_params import SchedulingParams from tensorrt_llm.scheduling_params import SchedulingParams
from dynamo._core import Context from dynamo._core import Context
...@@ -848,6 +849,19 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -848,6 +849,19 @@ class HandlerBase(BaseGenerativeHandler):
if value is not None if value is not None
} }
# Convert guided_decoding dict (from Rust serialization) to GuidedDecodingParams.
# Explicit field mapping avoids breakage if either side adds fields the other
# doesn't know about (e.g. Rust's "backend"/"choice" vs TRT-LLM's fields).
guided_decoding = overrides.pop("guided_decoding", None)
if guided_decoding is not None and isinstance(guided_decoding, dict):
overrides["guided_decoding"] = GuidedDecodingParams(
json=guided_decoding.get("json"),
regex=guided_decoding.get("regex"),
grammar=guided_decoding.get("grammar"),
json_object=guided_decoding.get("json_object", False),
structural_tag=guided_decoding.get("structural_tag"),
)
# NOTE: using `dataclasses.replace` has several benefits over a `setattr` based approach: # NOTE: using `dataclasses.replace` has several benefits over a `setattr` based approach:
# 1. it catches unsupported fields / attributes. # 1. it catches unsupported fields / attributes.
# 2. it executes the class's `__post_init__`, which may contain helpful validation logic. # 2. it executes the class's `__post_init__`, which may contain helpful validation logic.
......
...@@ -26,6 +26,7 @@ class MockSamplingParams: ...@@ -26,6 +26,7 @@ class MockSamplingParams:
repetition_penalty: float = 1.0 repetition_penalty: float = 1.0
seed: int | None = None seed: int | None = None
ignore_eos: bool = False ignore_eos: bool = False
guided_decoding: object | None = None
def __post_init__(self): def __post_init__(self):
"""Called after dataclass initialization (including via replace()).""" """Called after dataclass initialization (including via replace())."""
...@@ -150,3 +151,63 @@ class TestOverrideSamplingParams: ...@@ -150,3 +151,63 @@ class TestOverrideSamplingParams:
HandlerBase._override_sampling_params(sampling_params, request) HandlerBase._override_sampling_params(sampling_params, request)
mock_post_init.assert_called_once() mock_post_init.assert_called_once()
class TestGuidedDecodingFromToolChoice:
    """Check that a dict-valued ``guided_decoding`` override gets converted.

    The Rust frontend sends ``guided_decoding`` across the wire as a plain
    dict. TRT-LLM reads that field through attribute access (for example
    ``.json`` and ``.json_object``), so ``_override_sampling_params`` has to
    hand back a ``GuidedDecodingParams``-style object instead of the raw dict.
    """

    # Payload shape the Rust frontend serializes for tool_choice="required"
    # when exactly one tool is defined.
    GUIDED_DECODING_DICT = {
        "json": {
            "type": "array",
            "minItems": 1,
            "items": {
                "type": "object",
                "anyOf": [
                    {
                        "properties": {
                            "name": {
                                "type": "string",
                                "enum": ["get_weather"],
                            },
                            "parameters": {
                                "type": "object",
                                "properties": {
                                    "location": {"type": "string"},
                                },
                                "required": ["location"],
                            },
                        },
                        "required": ["name", "parameters"],
                    }
                ],
            },
        }
    }

    def test_guided_decoding_dict_is_converted(self):
        """A guided_decoding dict in sampling_options must not stay a dict.

        After the override, the field has to support the attribute lookups
        that downstream TRT-LLM code performs (``.json_object``, ``.json``)
        without raising AttributeError.
        """
        base_params = MockSamplingParams()
        sampling_options = {
            "temperature": 0.7,
            "guided_decoding": self.GUIDED_DECODING_DICT,
        }
        payload = {"sampling_options": sampling_options}

        overridden = HandlerBase._override_sampling_params(base_params, payload)
        converted = overridden.guided_decoding

        still_a_dict = isinstance(converted, dict)
        assert (
            not still_a_dict
        ), "guided_decoding should be converted from dict to GuidedDecodingParams"

        # Attributes that TRT-LLM's sampling_params.py reads downstream:
        # json_object falls back to False when absent from the dict, and the
        # JSON schema must pass through unchanged.
        expected_schema = self.GUIDED_DECODING_DICT["json"]
        assert converted.json_object is False
        assert converted.json == expected_schema
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment