Unverified Commit 01ecc8c7 authored by Yifan Jiang's avatar Yifan Jiang Committed by GitHub
Browse files

feat: add guided decoding backend config and choice support for TRT-LLM (#5762)


Signed-off-by: default avatarYifan Jiang <19356972+yifjiang@users.noreply.github.com>
parent 8de1c3e0
......@@ -409,6 +409,17 @@ class DynamoTrtllmArgGroup(ArgGroup):
arg_type=int,
help="FSDP size for DiT.",
)
# --- Guided Decoding ---
add_argument(
g,
flag_name="--guided-decoding-backend",
env_var="DYN_TRTLLM_GUIDED_DECODING_BACKEND",
default=None,
choices=["xgrammar", "llguidance"],
help="Backend to use for guided decoding (structured output). "
"Options: xgrammar, llguidance.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-async-cpu-offload",
......@@ -450,6 +461,7 @@ class DynamoTrtllmConfig(ConfigBase):
override_engine_args: str
publish_events_and_metrics: bool
disable_request_abort: bool
guided_decoding_backend: Optional[str] = None
disaggregation_mode: DisaggregationMode
modality: Modality
......
......@@ -17,6 +17,7 @@ import asyncio
import dataclasses
import logging
import os
import re
from contextlib import asynccontextmanager
from dataclasses import asdict, dataclass
from typing import Any, AsyncGenerator, Optional, Union
......@@ -876,9 +877,17 @@ class HandlerBase(BaseGenerativeHandler):
# doesn't know about (e.g. Rust's "backend"/"choice" vs TRT-LLM's fields).
guided_decoding = overrides.pop("guided_decoding", None)
if guided_decoding is not None and isinstance(guided_decoding, dict):
# TRT-LLM's GuidedDecodingParams doesn't have a "choice" field.
# Convert choice list to a regex pattern: (choice1|choice2|...)
# This matches the approach used by vLLM's outlines backend.
regex = guided_decoding.get("regex")
choice = guided_decoding.get("choice")
if choice and not regex:
regex = "(" + "|".join(re.escape(c) for c in choice) + ")"
overrides["guided_decoding"] = GuidedDecodingParams(
json=guided_decoding.get("json"),
regex=guided_decoding.get("regex"),
regex=regex,
grammar=guided_decoding.get("grammar"),
json_object=guided_decoding.get("json_object", False),
structural_tag=guided_decoding.get("structural_tag"),
......
......@@ -221,6 +221,82 @@ class TestGuidedDecodingFromToolChoice:
assert result.guided_decoding.json_object is False
assert result.guided_decoding.json == self.GUIDED_DECODING_DICT["json"]
def test_choice_converted_to_regex(self):
"""guided_decoding with 'choice' must be converted to a regex pattern.
TRT-LLM's GuidedDecodingParams doesn't have a 'choice' field.
The handler should convert choice=["yes", "no", "maybe"] to
regex="(yes|no|maybe)" so that GuidedDecodingParams can enforce it.
"""
sampling_params = MockSamplingParams()
request = {
"sampling_options": {
"guided_decoding": {
"choice": ["yes", "no", "maybe"],
},
}
}
result = HandlerBase._override_sampling_params(sampling_params, request)
assert not isinstance(result.guided_decoding, dict)
assert result.guided_decoding.regex == "(yes|no|maybe)"
assert result.guided_decoding.json is None
def test_choice_with_special_chars_escaped(self):
"""Choice values with regex special characters must be escaped."""
import re as re_mod
sampling_params = MockSamplingParams()
request = {
"sampling_options": {
"guided_decoding": {
"choice": ["yes (confirmed)", "no [rejected]"],
},
}
}
result = HandlerBase._override_sampling_params(sampling_params, request)
assert not isinstance(result.guided_decoding, dict)
expected = (
"("
+ "|".join(re_mod.escape(c) for c in ["yes (confirmed)", "no [rejected]"])
+ ")"
)
assert result.guided_decoding.regex == expected
def test_choice_not_used_when_regex_present(self):
"""If both choice and regex are specified, regex takes priority."""
sampling_params = MockSamplingParams()
request = {
"sampling_options": {
"guided_decoding": {
"choice": ["yes", "no"],
"regex": "[0-9]+",
},
}
}
result = HandlerBase._override_sampling_params(sampling_params, request)
assert result.guided_decoding.regex == "[0-9]+"
def test_empty_choice_ignored(self):
"""Empty choice list should not produce a regex."""
sampling_params = MockSamplingParams()
request = {
"sampling_options": {
"guided_decoding": {
"choice": [],
},
}
}
result = HandlerBase._override_sampling_params(sampling_params, request)
assert result.guided_decoding.regex is None
class _ConcreteHandler(HandlerBase):
"""Concrete subclass of HandlerBase for testing (satisfies abstract method)."""
......
......@@ -188,6 +188,14 @@ async def init_llm_worker(
"kv_connector_config": kv_connector_config,
}
# Add guided decoding backend if specified
if config.guided_decoding_backend is not None:
arg_map["guided_decoding_backend"] = config.guided_decoding_backend
logging.info(
"Guided decoding enabled with backend: %s",
config.guided_decoding_backend,
)
if config.extra_engine_args != "":
# TODO: Support extra engine args from json file as well.
arg_map = update_llm_args_with_extra_options(arg_map, config.extra_engine_args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment