Unverified Commit 7f076c2c authored by Yixin Dong, committed by GitHub

Update XGrammar to the latest API (#2176)


Co-authored-by: Ben Gitter <gitterbd@gmail.com>
parent 3c5538f7
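
For context, here is a minimal sketch of the compile path the updated backend uses, assembled only from the calls visible in this diff. The `gpt2` tokenizer and the trivial schema are illustrative placeholders, not part of the change.

```python
# Minimal sketch of the new xgrammar compile path, based on the calls in this diff.
# The "gpt2" tokenizer and the trivial schema are placeholders for illustration.
from transformers import AutoTokenizer
from xgrammar import GrammarCompiler, GrammarMatcher, TokenizerInfo

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
vocab_size = len(tokenizer)

# CachedGrammarCompiler / compile_json_schema_grammar are replaced by
# GrammarCompiler / compile_json_schema in the updated API.
tokenizer_info = TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size)
compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
compiled = compiler.compile_json_schema(schema='{"type": "object"}')

# GrammarMatcher no longer takes vocab_size; the tokenizer info carries it.
matcher = GrammarMatcher(compiled, max_rollback_tokens=200)
```
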
@@ -22,7 +22,7 @@ runtime_common = ["aiohttp", "decord", "fastapi",
     "packaging", "pillow", "prometheus-client>=0.20.0",
     "psutil", "pydantic", "python-multipart",
     "pyzmq>=25.1.2", "torchao", "uvicorn", "uvloop",
-    "modelscope", "xgrammar"]
+    "modelscope", "xgrammar==0.1.4"]
 srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1"]

 # HIP (Heterogeneous-computing Interface for Portability) for AMD

@@ -17,21 +17,14 @@ import logging
 from typing import List, Tuple

 import torch
-
-try:
-    from xgrammar import (
-        CachedGrammarCompiler,
-        CompiledGrammar,
-        GrammarMatcher,
-        TokenizerInfo,
-    )
-
-    import_error = None
-except ImportError as e:
-    CachedGrammarCompiler = CompiledGrammar = GrammarMatcher = TokenizerInfo = (
-        ImportError
-    )
-    import_error = e
+from xgrammar import (
+    CompiledGrammar,
+    GrammarCompiler,
+    GrammarMatcher,
+    TokenizerInfo,
+    allocate_token_bitmask,
+    apply_token_bitmask_inplace,
+)

 from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarBackend,
@@ -41,7 +34,7 @@ from sglang.srt.constrained.base_grammar_backend import (

 logger = logging.getLogger(__name__)

-MAX_ROLLBACK_TOKENS = 10
+MAX_ROLLBACK_TOKENS = 200


 class XGrammarGrammar(BaseGrammarObject):
@@ -86,21 +79,22 @@ class XGrammarGrammar(BaseGrammarObject):
     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
     ) -> torch.Tensor:
-        return self.matcher.allocate_token_bitmask(vocab_size, batch_size)
+        return allocate_token_bitmask(batch_size, vocab_size)

     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
         self.matcher.fill_next_token_bitmask(vocab_mask, idx)

     @staticmethod
     def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        GrammarMatcher.apply_token_bitmask_inplace(logits, vocab_mask)
+        if vocab_mask.device.type != logits.device.type:
+            # vocab_mask must then be on the same device as logits
+            # when applying the token bitmask, so we check and move if needed
+            vocab_mask = vocab_mask.to(logits.device)
+        apply_token_bitmask_inplace(logits, vocab_mask)

     def copy(self):
-        matcher = GrammarMatcher(
-            self.ctx,
-            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            vocab_size=self.vocab_size,
-        )
+        matcher = GrammarMatcher(self.ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
         return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
@@ -112,25 +106,18 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
     ):
         super().__init__()

-        if import_error:
-            logger.warning(
-                f"Ignore import error for the grammar backend: {import_error}"
-            )
-            self.grammar_cache = None
-            return
-
-        tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)
-        self.grammar_cache = CachedGrammarCompiler(tokenizer_info=tokenizer_info)
+        tokenizer_info = TokenizerInfo.from_huggingface(
+            tokenizer, vocab_size=vocab_size
+        )
+        self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
         self.vocab_size = vocab_size

     def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
-        if import_error:
-            raise import_error
-
         key_type, key_string = key

         if key_type == "json":
             try:
-                ctx = self.grammar_cache.compile_json_schema_grammar(schema=key_string)
+                ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
             except RuntimeError as e:
                 logging.warning(
                     f"Skip invalid json_schema: json_schema={key_string}, {e=}"
@@ -144,13 +131,9 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         else:
             raise ValueError(f"Invalid key_type: {key_type}")

-        matcher = GrammarMatcher(
-            ctx,
-            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            vocab_size=self.vocab_size,
-        )
+        matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
         return XGrammarGrammar(matcher, self.vocab_size, ctx)

     def reset(self):
-        if self.grammar_cache:
-            self.grammar_cache.clear()
+        if self.grammar_compiler:
+            self.grammar_compiler.clear_cache()
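
Below is a hedged sketch of the per-step masking flow the backend above drives during decoding. The tokenizer, schema, and random logits are illustrative placeholders, and `accept_token` is assumed from the `GrammarMatcher` API rather than taken from this diff.

```python
# Illustrative sketch of one constrained decoding step with the new helpers.
# The tokenizer, schema, and random logits are placeholders; accept_token is
# assumed from the GrammarMatcher API and does not appear in this diff.
import torch
from transformers import AutoTokenizer
from xgrammar import (
    GrammarCompiler,
    GrammarMatcher,
    TokenizerInfo,
    allocate_token_bitmask,
    apply_token_bitmask_inplace,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
vocab_size = len(tokenizer)
compiler = GrammarCompiler(
    tokenizer_info=TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size)
)
matcher = GrammarMatcher(
    compiler.compile_json_schema(schema='{"type": "object"}'),
    max_rollback_tokens=200,
)

# allocate_token_bitmask(batch_size, vocab_size) returns a bitmask tensor that
# may live on a different device than the logits, hence the move below.
vocab_mask = allocate_token_bitmask(1, vocab_size)
logits = torch.randn(1, vocab_size)  # stand-in for the model's output logits

matcher.fill_next_token_bitmask(vocab_mask, 0)    # fill row 0 of the bitmask
if vocab_mask.device.type != logits.device.type:  # mirrors apply_vocab_mask above
    vocab_mask = vocab_mask.to(logits.device)
apply_token_bitmask_inplace(logits, vocab_mask)   # mask out disallowed tokens

next_token = int(torch.argmax(logits, dim=-1)[0])
matcher.accept_token(next_token)                  # advance the matcher state
```
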
@@ -17,7 +17,7 @@ from sglang.test.test_utils import (
 )


-class TestJSONConstrained(unittest.TestCase):
+class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_MODEL_NAME_FOR_TEST
@@ -36,7 +36,12 @@ class TestJSONConstrained(unittest.TestCase):
             cls.model,
             cls.base_url,
             timeout=300,
-            other_args=["--max-running-requests", "10"],
+            other_args=[
+                "--max-running-requests",
+                "10",
+                "--grammar-backend",
+                "outlines",
+            ],
         )

     @classmethod
@@ -121,5 +126,33 @@ class TestJSONConstrained(unittest.TestCase):
             list(executor.map(self.run_decode, json_schemas))


+class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.json_schema = json.dumps(
+            {
+                "type": "object",
+                "properties": {
+                    "name": {"type": "string"},
+                    "population": {"type": "integer"},
+                },
+                "required": ["name", "population"],
+            }
+        )
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=300,
+            other_args=[
+                "--max-running-requests",
+                "10",
+                "--grammar-backend",
+                "xgrammar",
+            ],
+        )
+
+
 if __name__ == "__main__":
     unittest.main()