Unverified Commit 1e44ab95 authored by Yifan Jiang's avatar Yifan Jiang Committed by GitHub
Browse files

fix: guided decoding arg placement, None guard, test cleanup (#6617)


Signed-off-by: default avatarYifan Jiang <19356972+yifjiang@users.noreply.github.com>
parent ab63b8b6
......@@ -198,6 +198,17 @@ class DynamoTrtllmArgGroup(ArgGroup):
help="Maximum size of downloadable embedding files/Image URLs.",
)
# --- Guided Decoding ---
add_argument(
g,
flag_name="--guided-decoding-backend",
env_var="DYN_TRTLLM_GUIDED_DECODING_BACKEND",
default=None,
choices=["xgrammar", "llguidance"],
help="Backend to use for guided decoding (structured output). "
"Options: xgrammar, llguidance.",
)
diffusion_group = parser.add_argument_group(
"Diffusion Options [Experimental]",
"Options for video_diffusion modality",
......@@ -409,17 +420,6 @@ class DynamoTrtllmArgGroup(ArgGroup):
arg_type=int,
help="FSDP size for DiT.",
)
# --- Guided Decoding ---
add_argument(
g,
flag_name="--guided-decoding-backend",
env_var="DYN_TRTLLM_GUIDED_DECODING_BACKEND",
default=None,
choices=["xgrammar", "llguidance"],
help="Backend to use for guided decoding (structured output). "
"Options: xgrammar, llguidance.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-async-cpu-offload",
......
......@@ -883,7 +883,9 @@ class HandlerBase(BaseGenerativeHandler):
regex = guided_decoding.get("regex")
choice = guided_decoding.get("choice")
if choice and not regex:
regex = "(" + "|".join(re.escape(c) for c in choice) + ")"
valid_choices = [c for c in choice if c is not None]
if valid_choices:
regex = "(" + "|".join(re.escape(c) for c in valid_choices) + ")"
overrides["guided_decoding"] = GuidedDecodingParams(
json=guided_decoding.get("json"),
......
......@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import re as re_mod
from dataclasses import dataclass
from unittest import mock
from unittest.mock import MagicMock
......@@ -245,8 +246,6 @@ class TestGuidedDecodingFromToolChoice:
def test_choice_with_special_chars_escaped(self):
"""Choice values with regex special characters must be escaped."""
import re as re_mod
sampling_params = MockSamplingParams()
request = {
"sampling_options": {
......@@ -297,6 +296,37 @@ class TestGuidedDecodingFromToolChoice:
assert result.guided_decoding.regex is None
def test_choice_with_none_items_filtered(self):
"""Choice list with None items should filter them out."""
sampling_params = MockSamplingParams()
request = {
"sampling_options": {
"guided_decoding": {
"choice": [None, "yes", None, "no"],
},
}
}
result = HandlerBase._override_sampling_params(sampling_params, request)
assert not isinstance(result.guided_decoding, dict)
assert result.guided_decoding.regex == "(yes|no)"
def test_choice_all_none_items_no_regex(self):
"""Choice list with all None items should not produce a regex."""
sampling_params = MockSamplingParams()
request = {
"sampling_options": {
"guided_decoding": {
"choice": [None, None],
},
}
}
result = HandlerBase._override_sampling_params(sampling_params, request)
assert result.guided_decoding.regex is None
class _ConcreteHandler(HandlerBase):
"""Concrete subclass of HandlerBase for testing (satisfies abstract method)."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment