Unverified commit 218ab361, authored by Lianmin Zheng, committed by GitHub
Browse files

Do not let invalid grammar crash the server (#2023)

parent f407fcf9
......@@ -52,7 +52,7 @@ class BaseGrammarBackend:
else:
entry.value = self.init_value_impl(key)
entry.event.set()
return entry.value.copy()
return entry.value.copy() if entry.value else None
def init_value_impl(self, key: Tuple[str, str]) -> BaseGrammarObject:
raise NotImplementedError()
......@@ -62,7 +62,8 @@ class BaseGrammarBackend:
entry = self.cache.get(key)
if not entry or not entry.event.is_set():
return None
return self.cache[key].value.copy()
val = self.cache[key].value
return val.copy() if val else None
def get_future_value(self, key: Tuple[str, str]) -> Future:
return self.executor.submit(self.init_value, key)
......
......@@ -19,6 +19,7 @@ import json
import logging
from typing import Dict, List, Optional, Tuple, Union
import interegular
import torch
from outlines.fsm.guide import RegexGuide
from outlines.models.transformers import TransformerTokenizer
......@@ -147,17 +148,22 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
key_string,
whitespace_pattern=self.whitespace_pattern,
)
except NotImplementedError as e:
except (NotImplementedError, json.decoder.JSONDecodeError) as e:
logger.warning(
f"skip invalid json schema: json_schema={key_string}, {e=}"
f"Skip invalid json_schema: json_schema={key_string}, {e=}"
)
return None, key_string
return None
elif key_type == "regex":
regex = key_string
else:
raise ValueError(f"Invalid key_type: {key_type}")
guide = RegexGuide(regex, self.outlines_tokenizer)
try:
guide = RegexGuide(regex, self.outlines_tokenizer)
except interegular.patterns.InvalidSyntax as e:
logger.warning(f"skip invalid regex schema: {regex=}, {e=}")
return None
if self.allow_jump_forward:
jump_forward_map = OutlinesJumpForwardMap(regex)
else:
......
......@@ -15,6 +15,7 @@ limitations under the License.
"""Constrained decoding with xgrammar backend."""
import logging
from typing import List, Tuple
import torch
......@@ -25,6 +26,9 @@ from sglang.srt.constrained.base_grammar_backend import (
BaseGrammarObject,
)
logger = logging.getLogger(__name__)
MAX_ROLLBACK_TOKENS = 10
......@@ -97,9 +101,20 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
key_type, key_string = key
if key_type == "json":
ctx = self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
try:
ctx = self.grammar_cache.get_compiled_grammar_for_json_schema(
key_string
)
except RuntimeError as e:
logging.warning(
f"Skip invalid json_schema: json_schema={key_string}, {e=}"
)
return None
elif key_type == "regex":
raise ValueError("regex hasn't been supported by xgrammar yet")
logger.warning(
"regex hasn't been supported by xgrammar yet. This is skipped."
)
return None
else:
raise ValueError(f"Invalid key_type: {key_type}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment