Unverified Commit c722d9bd authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix dependency and error message for xgrammar (#2024)

parent 218ab361
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
"""The baseclass of backends for grammar-guided constrained decoding.""" """The baseclass of a backend for grammar-guided constrained decoding."""
from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
......
...@@ -22,7 +22,9 @@ from typing import Dict, List, Optional, Tuple, Union ...@@ -22,7 +22,9 @@ from typing import Dict, List, Optional, Tuple, Union
import interegular import interegular
import torch import torch
from outlines.fsm.guide import RegexGuide from outlines.fsm.guide import RegexGuide
from outlines.fsm.json_schema import build_regex_from_schema
from outlines.models.transformers import TransformerTokenizer from outlines.models.transformers import TransformerTokenizer
from pydantic import BaseModel
from sglang.srt.constrained.base_grammar_backend import ( from sglang.srt.constrained.base_grammar_backend import (
BaseGrammarBackend, BaseGrammarBackend,
...@@ -33,26 +35,6 @@ from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap ...@@ -33,26 +35,6 @@ from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Compatibility shim: older outlines exports build_regex_from_object directly;
# newer releases removed it in favor of build_regex_from_schema.
try:
    from outlines.fsm.json_schema import build_regex_from_object
except ImportError:
    # Since outlines 0.0.32, build_regex_from_object is replaced by
    # build_regex_from_schema, which only accepts a string schema as input,
    # so we rebuild the old convenience wrapper on top of it.
    from outlines.fsm.json_schema import build_regex_from_schema
    from pydantic import BaseModel

    def build_regex_from_object(
        object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
    ):
        """Build an outlines regex from a pydantic model class, a schema
        dict, or a JSON-schema string (mirrors the removed upstream helper)."""
        # A pydantic model *class* is recognized through its metaclass.
        if isinstance(object, type(BaseModel)):
            schema = json.dumps(object.model_json_schema())
        elif isinstance(object, Dict):
            schema = json.dumps(object)
        else:
            # Already a JSON-schema string; pass it through unchanged.
            schema = object
        return build_regex_from_schema(schema, whitespace_pattern)
class OutlinesGrammar(BaseGrammarObject): class OutlinesGrammar(BaseGrammarObject):
def __init__( def __init__(
self, self,
...@@ -169,3 +151,15 @@ class OutlinesGrammarBackend(BaseGrammarBackend): ...@@ -169,3 +151,15 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
else: else:
jump_forward_map = None jump_forward_map = None
return OutlinesGrammar(guide, jump_forward_map) return OutlinesGrammar(guide, jump_forward_map)
def build_regex_from_object(
    object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
):
    """Build an outlines regex from a schema given in one of three forms.

    Args:
        object: a pydantic model class, a schema dict, or a JSON-schema string.
        whitespace_pattern: optional whitespace pattern forwarded to outlines.

    Returns:
        The regex string produced by ``build_regex_from_schema``.
    """
    # A pydantic model *class* (not an instance) is detected via its metaclass.
    if isinstance(object, type(BaseModel)):
        return build_regex_from_schema(
            json.dumps(object.model_json_schema()), whitespace_pattern
        )
    # A plain dict schema is serialized to JSON text before conversion.
    if isinstance(object, Dict):
        return build_regex_from_schema(json.dumps(object), whitespace_pattern)
    # Otherwise assume the caller passed a ready-made JSON-schema string.
    return build_regex_from_schema(object, whitespace_pattern)
...@@ -19,7 +19,16 @@ import logging ...@@ -19,7 +19,16 @@ import logging
from typing import List, Tuple from typing import List, Tuple
import torch import torch
from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
# Import xgrammar lazily so the module can be imported without it installed;
# the recorded import_error is surfaced later when the backend is constructed.
try:
    from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher

    # None signals that the xgrammar backend is fully available.
    import_error = None
except ImportError as e:
    # Placeholder bindings keep later references/annotations resolvable even
    # when xgrammar is absent; using them at runtime is guarded by import_error.
    # NOTE(review): TokenizerInfo is bound here but is not imported in the try
    # branch above — confirm whether it should be added to that import list.
    CachedGrammarCompiler = CompiledGrammar = GrammarMatcher = TokenizerInfo = (
        ImportError
    )
    import_error = e
from sglang.srt.constrained.base_grammar_backend import ( from sglang.srt.constrained.base_grammar_backend import (
BaseGrammarBackend, BaseGrammarBackend,
...@@ -95,10 +104,21 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -95,10 +104,21 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
vocab_size: int, vocab_size: int,
): ):
super().__init__() super().__init__()
if import_error:
logger.warning(
f"Ignore import error for the grammar backend: {import_error}"
)
self.grammar_cache = None
return
self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer) self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
self.vocab_size = vocab_size self.vocab_size = vocab_size
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar: def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
if import_error:
raise import_error
key_type, key_string = key key_type, key_string = key
if key_type == "json": if key_type == "json":
try: try:
...@@ -126,4 +146,5 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -126,4 +146,5 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
return XGrammarGrammar(matcher, self.vocab_size, ctx) return XGrammarGrammar(matcher, self.vocab_size, ctx)
def reset(self):
    """Clear the compiled-grammar cache, if this backend created one.

    The cache is absent (falsy) when the xgrammar import failed at module
    load time, in which case reset is a no-op.
    """
    cache = self.grammar_cache
    if cache:
        cache.clear()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment