Unverified Commit 7734e9a2 authored by Russell Bryant's avatar Russell Bryant Committed by GitHub
Browse files

[Core] choice-based structured output with xgrammar (#12632)

parent 6224a9f6
...@@ -20,7 +20,7 @@ tiktoken >= 0.6.0 # Required for DBRX tokenizer ...@@ -20,7 +20,7 @@ tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.9, < 0.11 lm-format-enforcer >= 0.10.9, < 0.11
outlines == 0.1.11 outlines == 0.1.11
lark == 1.2.2 lark == 1.2.2
xgrammar >= 0.1.6; platform_machine == "x86_64" xgrammar >= 0.1.11; platform_machine == "x86_64"
typing_extensions >= 4.10 typing_extensions >= 4.10
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
partial-json-parser # used for parsing partial JSON outputs partial-json-parser # used for parsing partial JSON outputs
......
...@@ -49,11 +49,10 @@ def maybe_backend_fallback( ...@@ -49,11 +49,10 @@ def maybe_backend_fallback(
"Falling back to use outlines instead.") "Falling back to use outlines instead.")
guided_params.backend = "outlines" guided_params.backend = "outlines"
# xgrammar doesn't support regex or choice, fallback to outlines # xgrammar doesn't support regex, fallback to outlines
if guided_params.regex is not None or guided_params.choice is not None: if guided_params.regex is not None:
logger.warning( logger.warning("xgrammar does not support regex guided decoding. "
"xgrammar only supports json or grammar guided decoding. " "Falling back to use outlines instead.")
"Falling back to use outlines instead.")
guided_params.backend = "outlines" guided_params.backend = "outlines"
# xgrammar doesn't support some JSON schema features # xgrammar doesn't support some JSON schema features
......
...@@ -5,8 +5,9 @@ from __future__ import annotations ...@@ -5,8 +5,9 @@ from __future__ import annotations
import copy import copy
import json import json
import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, List
import torch import torch
from transformers import PreTrainedTokenizerFast from transformers import PreTrainedTokenizerFast
...@@ -228,11 +229,39 @@ class GrammarConfig: ...@@ -228,11 +229,39 @@ class GrammarConfig:
max_threads=max_threads, max_threads=max_threads,
tokenizer_data=tokenizer_data, tokenizer_data=tokenizer_data,
) )
elif guided_params.choice:
choice_str = GrammarConfig.choice_as_grammar(guided_params.choice)
try:
xgr.Grammar.from_ebnf(choice_str)
except RuntimeError as err:
raise ValueError(str(err)) from err
return cls(
grammar_str=choice_str,
vocab_size=model_config.hf_text_config.vocab_size,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads,
tokenizer_data=tokenizer_data,
)
else: else:
raise ValueError( raise ValueError(
"Currently only support JSON and EBNF grammar mode for xgrammar" "Currently only support JSON and EBNF grammar mode for xgrammar"
) )
@staticmethod
def escape_ebnf_string(s: str) -> str:
"""Escape special characters in a EBNF string."""
# Escape double quotes and backslashes
return re.sub(r'(["\\])', r'\\\1', s)
@staticmethod
def choice_as_grammar(choice: List[str] | None) -> str:
if choice is None:
raise ValueError("Choice is not set")
escaped_choices = (GrammarConfig.escape_ebnf_string(c) for c in choice)
grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices))
return grammar
@dataclass @dataclass
class XGrammarLogitsProcessor: class XGrammarLogitsProcessor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment