Commit b5dcfd41 (unverified), authored by Lorenzo Lu, committed by GitHub
Browse files

Add option to disable `any_whitespace` for `xgrammar` and `llguidance` backends. (#8919)


Co-authored-by: Chang Su <chang.s.su@oracle.com>
parent 5061b8fd
...@@ -224,13 +224,17 @@ def create_grammar_backend( ...@@ -224,13 +224,17 @@ def create_grammar_backend(
eos_list = list(eos_token_ids) if eos_token_ids else None eos_list = list(eos_token_ids) if eos_token_ids else None
grammar_backend = XGrammarGrammarBackend( grammar_backend = XGrammarGrammarBackend(
tokenizer, vocab_size=vocab_size, model_eos_token_ids=eos_list tokenizer,
vocab_size=vocab_size,
model_eos_token_ids=eos_list,
any_whitespace=not server_args.constrained_json_disable_any_whitespace,
) )
elif name == "llguidance": elif name == "llguidance":
from sglang.srt.constrained.llguidance_backend import GuidanceBackend from sglang.srt.constrained.llguidance_backend import GuidanceBackend
grammar_backend = GuidanceBackend( grammar_backend = GuidanceBackend(
tokenizer=tokenizer, tokenizer=tokenizer,
any_whitespace=not server_args.constrained_json_disable_any_whitespace,
whitespace_pattern=server_args.constrained_json_whitespace_pattern, whitespace_pattern=server_args.constrained_json_whitespace_pattern,
) )
elif name == "none": elif name == "none":
......
...@@ -110,12 +110,14 @@ class GuidanceBackend(BaseGrammarBackend): ...@@ -110,12 +110,14 @@ class GuidanceBackend(BaseGrammarBackend):
def __init__( def __init__(
self, self,
tokenizer, tokenizer,
any_whitespace: bool = True,
whitespace_pattern: Optional[str] = None, whitespace_pattern: Optional[str] = None,
n_vocab: Optional[int] = None, n_vocab: Optional[int] = None,
): ):
super().__init__() super().__init__()
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.any_whitespace = any_whitespace
self.whitespace_pattern = whitespace_pattern self.whitespace_pattern = whitespace_pattern
self.llguidance_tokenizer = from_tokenizer(self.tokenizer, n_vocab) self.llguidance_tokenizer = from_tokenizer(self.tokenizer, n_vocab)
...@@ -134,6 +136,7 @@ class GuidanceBackend(BaseGrammarBackend): ...@@ -134,6 +136,7 @@ class GuidanceBackend(BaseGrammarBackend):
serialized_grammar = LLMatcher.grammar_from_json_schema( serialized_grammar = LLMatcher.grammar_from_json_schema(
key_string, key_string,
defaults={ defaults={
"whitespace_flexible": self.any_whitespace,
"whitespace_pattern": self.whitespace_pattern, "whitespace_pattern": self.whitespace_pattern,
}, },
) )
......
...@@ -115,7 +115,7 @@ class OutlinesGrammarBackend(BaseGrammarBackend): ...@@ -115,7 +115,7 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
def __init__( def __init__(
self, self,
tokenizer, tokenizer,
whitespace_pattern: bool, whitespace_pattern: str | None,
): ):
super().__init__() super().__init__()
......
...@@ -167,6 +167,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -167,6 +167,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
tokenizer, tokenizer,
vocab_size: int, vocab_size: int,
model_eos_token_ids: Optional[List[int]] = None, model_eos_token_ids: Optional[List[int]] = None,
any_whitespace: bool = True,
): ):
super().__init__() super().__init__()
...@@ -188,6 +189,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -188,6 +189,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info) self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.override_stop_tokens = override_stop_tokens self.override_stop_tokens = override_stop_tokens
self.any_whitespace = any_whitespace
def _from_context( def _from_context(
self, ctx: CompiledGrammar, key_string: str, grammar_stats: GrammarStats self, ctx: CompiledGrammar, key_string: str, grammar_stats: GrammarStats
...@@ -212,7 +214,9 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -212,7 +214,9 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
# Note: This builtin JSON grammar includes *all* valid JSON (including, for example, arrays at the root) # Note: This builtin JSON grammar includes *all* valid JSON (including, for example, arrays at the root)
ctx = self.grammar_compiler.compile_builtin_json_grammar() ctx = self.grammar_compiler.compile_builtin_json_grammar()
else: else:
ctx = self.grammar_compiler.compile_json_schema(schema=key_string) ctx = self.grammar_compiler.compile_json_schema(
schema=key_string, any_whitespace=self.any_whitespace
)
except (RuntimeError, json.decoder.JSONDecodeError) as e: except (RuntimeError, json.decoder.JSONDecodeError) as e:
logging.error(f"Hit invalid json_schema: {key_string=}, {e=}") logging.error(f"Hit invalid json_schema: {key_string=}, {e=}")
......
...@@ -227,6 +227,7 @@ class ServerArgs: ...@@ -227,6 +227,7 @@ class ServerArgs:
stream_output: bool = False stream_output: bool = False
random_seed: Optional[int] = None random_seed: Optional[int] = None
constrained_json_whitespace_pattern: Optional[str] = None constrained_json_whitespace_pattern: Optional[str] = None
constrained_json_disable_any_whitespace: bool = False
watchdog_timeout: float = 300 watchdog_timeout: float = 300
dist_timeout: Optional[int] = None # timeout for torch.distributed dist_timeout: Optional[int] = None # timeout for torch.distributed
download_dir: Optional[str] = None download_dir: Optional[str] = None
...@@ -1683,7 +1684,12 @@ class ServerArgs: ...@@ -1683,7 +1684,12 @@ class ServerArgs:
"--constrained-json-whitespace-pattern", "--constrained-json-whitespace-pattern",
type=str, type=str,
default=ServerArgs.constrained_json_whitespace_pattern, default=ServerArgs.constrained_json_whitespace_pattern,
help="(outlines backend only) Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*", help="(outlines and llguidance backends only) Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
)
parser.add_argument(
"--constrained-json-disable-any-whitespace",
action="store_true",
help="(xgrammar and llguidance backends only) Enforce compact representation in JSON constrained output.",
) )
parser.add_argument( parser.add_argument(
"--watchdog-timeout", "--watchdog-timeout",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment