"lib/ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "b950034bce4b0aa38b0a9c69139901205efb3407"
Unverified Commit 1e44ab95 authored by Yifan Jiang's avatar Yifan Jiang Committed by GitHub
Browse files

fix: guided decoding arg placement, None guard, test cleanup (#6617)


Signed-off-by: default avatarYifan Jiang <19356972+yifjiang@users.noreply.github.com>
parent ab63b8b6
...@@ -198,6 +198,17 @@ class DynamoTrtllmArgGroup(ArgGroup): ...@@ -198,6 +198,17 @@ class DynamoTrtllmArgGroup(ArgGroup):
help="Maximum size of downloadable embedding files/Image URLs.", help="Maximum size of downloadable embedding files/Image URLs.",
) )
# --- Guided Decoding ---
add_argument(
g,
flag_name="--guided-decoding-backend",
env_var="DYN_TRTLLM_GUIDED_DECODING_BACKEND",
default=None,
choices=["xgrammar", "llguidance"],
help="Backend to use for guided decoding (structured output). "
"Options: xgrammar, llguidance.",
)
diffusion_group = parser.add_argument_group( diffusion_group = parser.add_argument_group(
"Diffusion Options [Experimental]", "Diffusion Options [Experimental]",
"Options for video_diffusion modality", "Options for video_diffusion modality",
...@@ -409,17 +420,6 @@ class DynamoTrtllmArgGroup(ArgGroup): ...@@ -409,17 +420,6 @@ class DynamoTrtllmArgGroup(ArgGroup):
arg_type=int, arg_type=int,
help="FSDP size for DiT.", help="FSDP size for DiT.",
) )
# --- Guided Decoding ---
add_argument(
g,
flag_name="--guided-decoding-backend",
env_var="DYN_TRTLLM_GUIDED_DECODING_BACKEND",
default=None,
choices=["xgrammar", "llguidance"],
help="Backend to use for guided decoding (structured output). "
"Options: xgrammar, llguidance.",
)
add_negatable_bool_argument( add_negatable_bool_argument(
diffusion_group, diffusion_group,
flag_name="--enable-async-cpu-offload", flag_name="--enable-async-cpu-offload",
......
...@@ -883,7 +883,9 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -883,7 +883,9 @@ class HandlerBase(BaseGenerativeHandler):
regex = guided_decoding.get("regex") regex = guided_decoding.get("regex")
choice = guided_decoding.get("choice") choice = guided_decoding.get("choice")
if choice and not regex: if choice and not regex:
regex = "(" + "|".join(re.escape(c) for c in choice) + ")" valid_choices = [c for c in choice if c is not None]
if valid_choices:
regex = "(" + "|".join(re.escape(c) for c in valid_choices) + ")"
overrides["guided_decoding"] = GuidedDecodingParams( overrides["guided_decoding"] = GuidedDecodingParams(
json=guided_decoding.get("json"), json=guided_decoding.get("json"),
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import re as re_mod
from dataclasses import dataclass from dataclasses import dataclass
from unittest import mock from unittest import mock
from unittest.mock import MagicMock from unittest.mock import MagicMock
...@@ -245,8 +246,6 @@ class TestGuidedDecodingFromToolChoice: ...@@ -245,8 +246,6 @@ class TestGuidedDecodingFromToolChoice:
def test_choice_with_special_chars_escaped(self): def test_choice_with_special_chars_escaped(self):
"""Choice values with regex special characters must be escaped.""" """Choice values with regex special characters must be escaped."""
import re as re_mod
sampling_params = MockSamplingParams() sampling_params = MockSamplingParams()
request = { request = {
"sampling_options": { "sampling_options": {
...@@ -297,6 +296,37 @@ class TestGuidedDecodingFromToolChoice: ...@@ -297,6 +296,37 @@ class TestGuidedDecodingFromToolChoice:
assert result.guided_decoding.regex is None assert result.guided_decoding.regex is None
def test_choice_with_none_items_filtered(self):
"""Choice list with None items should filter them out."""
sampling_params = MockSamplingParams()
request = {
"sampling_options": {
"guided_decoding": {
"choice": [None, "yes", None, "no"],
},
}
}
result = HandlerBase._override_sampling_params(sampling_params, request)
assert not isinstance(result.guided_decoding, dict)
assert result.guided_decoding.regex == "(yes|no)"
def test_choice_all_none_items_no_regex(self):
"""Choice list with all None items should not produce a regex."""
sampling_params = MockSamplingParams()
request = {
"sampling_options": {
"guided_decoding": {
"choice": [None, None],
},
}
}
result = HandlerBase._override_sampling_params(sampling_params, request)
assert result.guided_decoding.regex is None
class _ConcreteHandler(HandlerBase): class _ConcreteHandler(HandlerBase):
"""Concrete subclass of HandlerBase for testing (satisfies abstract method).""" """Concrete subclass of HandlerBase for testing (satisfies abstract method)."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment