Unverified Commit 840c5dbc authored by Ninglin Du's avatar Ninglin Du Committed by GitHub
Browse files

[FIX] Catch syntax error of Regex Guide to avoid crash (#1521)

parent 63e845d0
...@@ -14,13 +14,17 @@ limitations under the License. ...@@ -14,13 +14,17 @@ limitations under the License.
""" """
"""Cache for the compressed finite state machine.""" """Cache for the compressed finite state machine."""
import logging
from interegular import InvalidSyntax, parse_pattern
from outlines.fsm.json_schema import build_regex_from_schema from outlines.fsm.json_schema import build_regex_from_schema
from transformers import AutoTokenizer from transformers import AutoTokenizer
from sglang.srt.constrained import RegexGuide, TransformerTokenizer from sglang.srt.constrained import RegexGuide, TransformerTokenizer
from sglang.srt.constrained.base_tool_cache import BaseToolCache from sglang.srt.constrained.base_tool_cache import BaseToolCache
logger = logging.getLogger(__name__)
class FSMCache(BaseToolCache): class FSMCache(BaseToolCache):
def __init__( def __init__(
...@@ -76,5 +80,9 @@ class FSMCache(BaseToolCache): ...@@ -76,5 +80,9 @@ class FSMCache(BaseToolCache):
regex = key_string regex = key_string
else: else:
raise ValueError(f"Invalid key_type: {key_type}") raise ValueError(f"Invalid key_type: {key_type}")
try:
parse_pattern(regex)
except InvalidSyntax as e:
logger.warning(f"skip invalid regex guide: {regex=}, {e=}")
return None, regex
return RegexGuide(regex, self.outlines_tokenizer), regex return RegexGuide(regex, self.outlines_tokenizer), regex
...@@ -19,10 +19,12 @@ Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/ ...@@ -19,10 +19,12 @@ Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
""" """
import dataclasses import dataclasses
import logging
from collections import defaultdict from collections import defaultdict
import interegular import interegular
import outlines.caching import outlines.caching
from interegular import InvalidSyntax
from sglang.srt.constrained import ( from sglang.srt.constrained import (
FSMInfo, FSMInfo,
...@@ -34,6 +36,8 @@ from sglang.srt.constrained.base_tool_cache import BaseToolCache ...@@ -34,6 +36,8 @@ from sglang.srt.constrained.base_tool_cache import BaseToolCache
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
logger = logging.getLogger(__name__)
@dataclasses.dataclass @dataclasses.dataclass
class JumpEdge: class JumpEdge:
...@@ -47,7 +51,12 @@ class JumpForwardMap: ...@@ -47,7 +51,12 @@ class JumpForwardMap:
def __init__(self, regex_string): def __init__(self, regex_string):
@disk_cache() @disk_cache()
def _init_state_to_jump_forward(regex_string): def _init_state_to_jump_forward(regex_string):
regex_pattern = interegular.parse_pattern(regex_string) try:
regex_pattern = interegular.parse_pattern(regex_string)
except InvalidSyntax as e:
logger.warning(f"skip invalid regex: {regex_string}, {e=}")
self.state_to_jump_forward = None
return
byte_fsm = make_byte_level_fsm( byte_fsm = make_byte_level_fsm(
regex_pattern.to_fsm().reduce(), keep_utf8=True regex_pattern.to_fsm().reduce(), keep_utf8=True
...@@ -165,7 +174,11 @@ class JumpForwardCache(BaseToolCache): ...@@ -165,7 +174,11 @@ class JumpForwardCache(BaseToolCache):
super().__init__() super().__init__()
def init_value(self, regex): def init_value(self, regex):
return JumpForwardMap(regex) forward_map = JumpForwardMap(regex)
if forward_map.state_to_jump_forward:
return forward_map
else:
return None
def test_main(regex_string): def test_main(regex_string):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment