Unverified commit 5bfb30a5, authored by Michael Goin and committed by GitHub
Browse files

[Bugfix] Fix CFGGuide and use outlines for grammars that can't convert to GBNF (#11389)


Signed-off-by: mgoin <michael@neuralmagic.com>
parent e51719ae
...@@ -174,11 +174,6 @@ def test_guided_choice_completion(sample_guided_choice, llm, ...@@ -174,11 +174,6 @@ def test_guided_choice_completion(sample_guided_choice, llm,
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_grammar(sample_sql_statements, llm, def test_guided_grammar(sample_sql_statements, llm,
guided_decoding_backend: str): guided_decoding_backend: str):
if guided_decoding_backend == "outlines":
pytest.skip("Outlines backend fails in this test case with:\n"
"AttributeError: Error in model execution: 'ParserConf' "
"object has no attribute 'deterministic'")
sampling_params = SamplingParams(temperature=0.8, sampling_params = SamplingParams(temperature=0.8,
top_p=0.95, top_p=0.95,
max_tokens=1000, max_tokens=1000,
......
...@@ -3,6 +3,9 @@ from __future__ import annotations ...@@ -3,6 +3,9 @@ from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
from vllm.platforms import CpuArchEnum, current_platform from vllm.platforms import CpuArchEnum, current_platform
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -15,76 +18,6 @@ if TYPE_CHECKING: ...@@ -15,76 +18,6 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
    """Check if JSON schema contains features unsupported by xgrammar.

    The schema is walked recursively. It is reported as unsupported when
    any (sub)schema either:

    - uses the ``pattern`` keyword (regex restriction), or
    - declares an ``integer``/``number`` type together with a numeric
      range keyword (``minimum``, ``maximum``, ``exclusiveMinimum``,
      ``exclusiveMaximum``, ``multipleOf``).

    Args:
        schema: The JSON schema, already parsed into Python objects.

    Returns:
        True if an unsupported feature was found, False otherwise
        (including when ``schema`` is not a dict).
    """
    numeric_types = ("integer", "number")
    range_keywords = ("minimum", "maximum", "exclusiveMinimum",
                      "exclusiveMaximum", "multipleOf")
    # Keywords whose value maps user-chosen names to subschemas. The keys
    # of such a mapping are names, not schema keywords, so a property that
    # merely happens to be called "pattern" must not be flagged.
    name_maps = ("properties", "patternProperties", "$defs", "definitions")

    def is_numeric(type_value: object) -> bool:
        # "type" may be a single name or a list such as ["integer", "null"].
        if isinstance(type_value, str):
            return type_value in numeric_types
        if isinstance(type_value, list):
            return any(
                isinstance(name, str) and name in numeric_types
                for name in type_value)
        return False

    def check_object(obj: object) -> bool:
        if not isinstance(obj, dict):
            return False

        # Check for pattern restrictions
        if "pattern" in obj:
            return True

        # Check for numeric ranges
        if is_numeric(obj.get("type")) and any(
                key in obj for key in range_keywords):
            return True

        # Recursively check all nested objects and arrays
        for key, value in obj.items():
            if isinstance(value, dict):
                if key in name_maps:
                    # Inspect only the subschemas, never the names.
                    if any(check_object(sub) for sub in value.values()):
                        return True
                elif check_object(value):
                    return True
            elif isinstance(value, list):
                if any(check_object(item) for item in value):
                    return True

        return False

    return check_object(schema)
def has_lmf_unsupported_json_features(schema: dict) -> bool:
    """
    Check if JSON schema contains features unsupported
    by lm_format_enforcer.

    Known issues:
    - Regex patterns:
        "grade": {
            "type": "string",
            "pattern": "^[A-D]$"  # Regex pattern
        },

    Args:
        schema: The JSON schema, already parsed into Python objects.

    Returns:
        True if any (sub)schema uses the ``pattern`` keyword, False
        otherwise (including when ``schema`` is not a dict).
    """
    # Keywords whose value maps user-chosen names to subschemas. The keys
    # of such a mapping are names, not schema keywords, so a property that
    # merely happens to be called "pattern" must not be flagged.
    name_maps = ("properties", "patternProperties", "$defs", "definitions")

    def check_object(obj: object) -> bool:
        if not isinstance(obj, dict):
            return False

        # Check for pattern restrictions
        if "pattern" in obj:
            return True

        # Recursively check all nested objects and arrays
        for key, value in obj.items():
            if isinstance(value, dict):
                if key in name_maps:
                    # Inspect only the subschemas, never the names.
                    if any(check_object(sub) for sub in value.values()):
                        return True
                elif check_object(value):
                    return True
            elif isinstance(value, list):
                if any(check_object(item) for item in value):
                    return True

        return False

    return check_object(schema)
def maybe_backend_fallback( def maybe_backend_fallback(
guided_params: GuidedDecodingParams) -> GuidedDecodingParams: guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
# lm-format-enforce doesn't support grammar, fallback to xgrammar # lm-format-enforce doesn't support grammar, fallback to xgrammar
...@@ -127,6 +60,20 @@ def maybe_backend_fallback( ...@@ -127,6 +60,20 @@ def maybe_backend_fallback(
"Falling back to use outlines instead.") "Falling back to use outlines instead.")
guided_params.backend = "outlines" guided_params.backend = "outlines"
# xgrammar only supports GBNF grammars, so we must convert Lark.
# We must check if the grammar is likely Lark and if that
# grammar is convertible to GBNF
elif (guided_params.grammar is not None
and grammar_is_likely_lark(guided_params.grammar)):
try:
convert_lark_to_gbnf(guided_params.grammar)
except Exception:
logger.warning(
"xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
if (guided_params.backend == "outlines" if (guided_params.backend == "outlines"
and guided_params.json_object is not None): and guided_params.json_object is not None):
# outlines doesn't support json_object, fallback to xgrammar # outlines doesn't support json_object, fallback to xgrammar
......
...@@ -21,10 +21,11 @@ from typing import Callable, DefaultDict, Dict, List, Union ...@@ -21,10 +21,11 @@ from typing import Callable, DefaultDict, Dict, List, Union
import numpy as np import numpy as np
import torch import torch
from lark import Lark
from outlines import grammars from outlines import grammars
from outlines.caching import cache from outlines.caching import cache
from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write from outlines.fsm.guide import (CFGGuide, CFGState, Generate, Guide,
RegexGuide, Write)
from outlines.fsm.parsing import PartialLark
from outlines_core.fsm.json_schema import build_regex_from_schema from outlines_core.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
...@@ -34,7 +35,9 @@ class BaseLogitsProcessor: ...@@ -34,7 +35,9 @@ class BaseLogitsProcessor:
def __init__(self, guide: Guide): def __init__(self, guide: Guide):
self._guide: Guide = guide self._guide: Guide = guide
self._fsm_state: DefaultDict[int, int] = defaultdict(int) # CFGState is used for the FSM state for CFGGuide
self._fsm_state: DefaultDict[int, Union[int,
CFGState]] = defaultdict(int)
def __call__(self, input_ids: List[int], def __call__(self, input_ids: List[int],
scores: torch.Tensor) -> torch.Tensor: scores: torch.Tensor) -> torch.Tensor:
...@@ -54,15 +57,13 @@ class BaseLogitsProcessor: ...@@ -54,15 +57,13 @@ class BaseLogitsProcessor:
# On the first time this is called, we simply re-create # On the first time this is called, we simply re-create
# the Lark object. # the Lark object.
if isinstance(self._guide, CFGGuide): if isinstance(self._guide, CFGGuide):
self._guide.parser = Lark( self._guide.parser = PartialLark(
self._guide.cfg_string, self._guide.cfg_string,
parser="lalr", parser="lalr",
lexer="contextual",
propagate_positions=False,
maybe_placeholders=False,
regex=True,
import_paths=[grammars.GRAMMAR_PATH], import_paths=[grammars.GRAMMAR_PATH],
) )
self._fsm_state[seq_id] = CFGState(
parser_state=self._guide.parser.parse(""), prev_token=None)
instruction = self._guide.get_next_instruction( instruction = self._guide.get_next_instruction(
state=self._fsm_state[seq_id]) state=self._fsm_state[seq_id])
...@@ -200,7 +201,8 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase): ...@@ -200,7 +201,8 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
string = tokenizer.convert_tokens_to_string([token]) string = tokenizer.convert_tokens_to_string([token])
# A hack to handle missing spaces to HF's Llama tokenizers # A hack to handle missing spaces to HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": if (type(token) is str and token.startswith(SPIECE_UNDERLINE)
or token == "<0x20>"):
return " " + string return " " + string
return string return string
...@@ -211,6 +213,9 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase): ...@@ -211,6 +213,9 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
"""Sync vLLM's decoder with the outlines by returning list.""" """Sync vLLM's decoder with the outlines by returning list."""
def new_decoder(inp_tokens: List[int]) -> List[str]: def new_decoder(inp_tokens: List[int]) -> List[str]:
if (isinstance(inp_tokens, list) and len(inp_tokens) == 1
and isinstance(inp_tokens[0], list)):
inp_tokens = inp_tokens[0]
return [decoder(inp_tokens)] return [decoder(inp_tokens)]
return new_decoder return new_decoder
......
import re import re
def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
    """Check if JSON schema contains features unsupported by xgrammar.

    The schema is walked recursively. It is reported as unsupported when
    any (sub)schema either:

    - uses the ``pattern`` keyword (regex restriction), or
    - declares an ``integer``/``number`` type together with a numeric
      range keyword (``minimum``, ``maximum``, ``exclusiveMinimum``,
      ``exclusiveMaximum``, ``multipleOf``).

    Args:
        schema: The JSON schema, already parsed into Python objects.

    Returns:
        True if an unsupported feature was found, False otherwise
        (including when ``schema`` is not a dict).
    """
    numeric_types = ("integer", "number")
    range_keywords = ("minimum", "maximum", "exclusiveMinimum",
                      "exclusiveMaximum", "multipleOf")
    # Keywords whose value maps user-chosen names to subschemas. The keys
    # of such a mapping are names, not schema keywords, so a property that
    # merely happens to be called "pattern" must not be flagged.
    name_maps = ("properties", "patternProperties", "$defs", "definitions")

    def is_numeric(type_value: object) -> bool:
        # "type" may be a single name or a list such as ["integer", "null"].
        if isinstance(type_value, str):
            return type_value in numeric_types
        if isinstance(type_value, list):
            return any(
                isinstance(name, str) and name in numeric_types
                for name in type_value)
        return False

    def check_object(obj: object) -> bool:
        if not isinstance(obj, dict):
            return False

        # Check for pattern restrictions
        if "pattern" in obj:
            return True

        # Check for numeric ranges
        if is_numeric(obj.get("type")) and any(
                key in obj for key in range_keywords):
            return True

        # Recursively check all nested objects and arrays
        for key, value in obj.items():
            if isinstance(value, dict):
                if key in name_maps:
                    # Inspect only the subschemas, never the names.
                    if any(check_object(sub) for sub in value.values()):
                        return True
                elif check_object(value):
                    return True
            elif isinstance(value, list):
                if any(check_object(item) for item in value):
                    return True

        return False

    return check_object(schema)
def has_lmf_unsupported_json_features(schema: dict) -> bool:
    """
    Check if JSON schema contains features unsupported
    by lm_format_enforcer.

    Known issues:
    - Regex patterns:
        "grade": {
            "type": "string",
            "pattern": "^[A-D]$"  # Regex pattern
        },

    Args:
        schema: The JSON schema, already parsed into Python objects.

    Returns:
        True if any (sub)schema uses the ``pattern`` keyword, False
        otherwise (including when ``schema`` is not a dict).
    """
    # Keywords whose value maps user-chosen names to subschemas. The keys
    # of such a mapping are names, not schema keywords, so a property that
    # merely happens to be called "pattern" must not be flagged.
    name_maps = ("properties", "patternProperties", "$defs", "definitions")

    def check_object(obj: object) -> bool:
        if not isinstance(obj, dict):
            return False

        # Check for pattern restrictions
        if "pattern" in obj:
            return True

        # Recursively check all nested objects and arrays
        for key, value in obj.items():
            if isinstance(value, dict):
                if key in name_maps:
                    # Inspect only the subschemas, never the names.
                    if any(check_object(sub) for sub in value.values()):
                        return True
                elif check_object(value):
                    return True
            elif isinstance(value, list):
                if any(check_object(item) for item in value):
                    return True

        return False

    return check_object(schema)
def grammar_is_likely_lark(grammar_str: str) -> bool: def grammar_is_likely_lark(grammar_str: str) -> bool:
""" """
Check if grammar appears to use Lark syntax. Check if grammar appears to use Lark syntax.
......
...@@ -14,8 +14,8 @@ try: ...@@ -14,8 +14,8 @@ try:
except ImportError: except ImportError:
pass pass
from vllm.model_executor.guided_decoding.xgrammar_utils import ( from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
convert_lark_to_gbnf, grammar_is_likely_lark) grammar_is_likely_lark)
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
if TYPE_CHECKING: if TYPE_CHECKING:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment