Unverified Commit 01ecc8c7 authored by Yifan Jiang's avatar Yifan Jiang Committed by GitHub
Browse files

feat: add guided decoding backend config and choice support for TRT-LLM (#5762)


Signed-off-by: default avatarYifan Jiang <19356972+yifjiang@users.noreply.github.com>
parent 8de1c3e0
...@@ -409,6 +409,17 @@ class DynamoTrtllmArgGroup(ArgGroup): ...@@ -409,6 +409,17 @@ class DynamoTrtllmArgGroup(ArgGroup):
arg_type=int, arg_type=int,
help="FSDP size for DiT.", help="FSDP size for DiT.",
) )
# --- Guided Decoding ---
add_argument(
g,
flag_name="--guided-decoding-backend",
env_var="DYN_TRTLLM_GUIDED_DECODING_BACKEND",
default=None,
choices=["xgrammar", "llguidance"],
help="Backend to use for guided decoding (structured output). "
"Options: xgrammar, llguidance.",
)
add_negatable_bool_argument( add_negatable_bool_argument(
diffusion_group, diffusion_group,
flag_name="--enable-async-cpu-offload", flag_name="--enable-async-cpu-offload",
...@@ -450,6 +461,7 @@ class DynamoTrtllmConfig(ConfigBase): ...@@ -450,6 +461,7 @@ class DynamoTrtllmConfig(ConfigBase):
override_engine_args: str override_engine_args: str
publish_events_and_metrics: bool publish_events_and_metrics: bool
disable_request_abort: bool disable_request_abort: bool
guided_decoding_backend: Optional[str] = None
disaggregation_mode: DisaggregationMode disaggregation_mode: DisaggregationMode
modality: Modality modality: Modality
......
...@@ -17,6 +17,7 @@ import asyncio ...@@ -17,6 +17,7 @@ import asyncio
import dataclasses import dataclasses
import logging import logging
import os import os
import re
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from typing import Any, AsyncGenerator, Optional, Union from typing import Any, AsyncGenerator, Optional, Union
...@@ -876,9 +877,17 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -876,9 +877,17 @@ class HandlerBase(BaseGenerativeHandler):
# doesn't know about (e.g. Rust's "backend"/"choice" vs TRT-LLM's fields). # doesn't know about (e.g. Rust's "backend"/"choice" vs TRT-LLM's fields).
guided_decoding = overrides.pop("guided_decoding", None) guided_decoding = overrides.pop("guided_decoding", None)
if guided_decoding is not None and isinstance(guided_decoding, dict): if guided_decoding is not None and isinstance(guided_decoding, dict):
# TRT-LLM's GuidedDecodingParams doesn't have a "choice" field.
# Convert choice list to a regex pattern: (choice1|choice2|...)
# This matches the approach used by vLLM's outlines backend.
regex = guided_decoding.get("regex")
choice = guided_decoding.get("choice")
if choice and not regex:
regex = "(" + "|".join(re.escape(c) for c in choice) + ")"
overrides["guided_decoding"] = GuidedDecodingParams( overrides["guided_decoding"] = GuidedDecodingParams(
json=guided_decoding.get("json"), json=guided_decoding.get("json"),
regex=guided_decoding.get("regex"), regex=regex,
grammar=guided_decoding.get("grammar"), grammar=guided_decoding.get("grammar"),
json_object=guided_decoding.get("json_object", False), json_object=guided_decoding.get("json_object", False),
structural_tag=guided_decoding.get("structural_tag"), structural_tag=guided_decoding.get("structural_tag"),
......
...@@ -221,6 +221,82 @@ class TestGuidedDecodingFromToolChoice: ...@@ -221,6 +221,82 @@ class TestGuidedDecodingFromToolChoice:
assert result.guided_decoding.json_object is False assert result.guided_decoding.json_object is False
assert result.guided_decoding.json == self.GUIDED_DECODING_DICT["json"] assert result.guided_decoding.json == self.GUIDED_DECODING_DICT["json"]
def test_choice_converted_to_regex(self):
"""guided_decoding with 'choice' must be converted to a regex pattern.
TRT-LLM's GuidedDecodingParams doesn't have a 'choice' field.
The handler should convert choice=["yes", "no", "maybe"] to
regex="(yes|no|maybe)" so that GuidedDecodingParams can enforce it.
"""
sampling_params = MockSamplingParams()
request = {
"sampling_options": {
"guided_decoding": {
"choice": ["yes", "no", "maybe"],
},
}
}
result = HandlerBase._override_sampling_params(sampling_params, request)
assert not isinstance(result.guided_decoding, dict)
assert result.guided_decoding.regex == "(yes|no|maybe)"
assert result.guided_decoding.json is None
def test_choice_with_special_chars_escaped(self):
"""Choice values with regex special characters must be escaped."""
import re as re_mod
sampling_params = MockSamplingParams()
request = {
"sampling_options": {
"guided_decoding": {
"choice": ["yes (confirmed)", "no [rejected]"],
},
}
}
result = HandlerBase._override_sampling_params(sampling_params, request)
assert not isinstance(result.guided_decoding, dict)
expected = (
"("
+ "|".join(re_mod.escape(c) for c in ["yes (confirmed)", "no [rejected]"])
+ ")"
)
assert result.guided_decoding.regex == expected
def test_choice_not_used_when_regex_present(self):
"""If both choice and regex are specified, regex takes priority."""
sampling_params = MockSamplingParams()
request = {
"sampling_options": {
"guided_decoding": {
"choice": ["yes", "no"],
"regex": "[0-9]+",
},
}
}
result = HandlerBase._override_sampling_params(sampling_params, request)
assert result.guided_decoding.regex == "[0-9]+"
def test_empty_choice_ignored(self):
"""Empty choice list should not produce a regex."""
sampling_params = MockSamplingParams()
request = {
"sampling_options": {
"guided_decoding": {
"choice": [],
},
}
}
result = HandlerBase._override_sampling_params(sampling_params, request)
assert result.guided_decoding.regex is None
class _ConcreteHandler(HandlerBase): class _ConcreteHandler(HandlerBase):
"""Concrete subclass of HandlerBase for testing (satisfies abstract method).""" """Concrete subclass of HandlerBase for testing (satisfies abstract method)."""
......
...@@ -188,6 +188,14 @@ async def init_llm_worker( ...@@ -188,6 +188,14 @@ async def init_llm_worker(
"kv_connector_config": kv_connector_config, "kv_connector_config": kv_connector_config,
} }
# Add guided decoding backend if specified
if config.guided_decoding_backend is not None:
arg_map["guided_decoding_backend"] = config.guided_decoding_backend
logging.info(
"Guided decoding enabled with backend: %s",
config.guided_decoding_backend,
)
if config.extra_engine_args != "": if config.extra_engine_args != "":
# TODO: Support extra engine args from json file as well. # TODO: Support extra engine args from json file as well.
arg_map = update_llm_args_with_extra_options(arg_map, config.extra_engine_args) arg_map = update_llm_args_with_extra_options(arg_map, config.extra_engine_args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment