Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
7f076c2c
Unverified
Commit
7f076c2c
authored
Nov 25, 2024
by
Yixin Dong
Committed by
GitHub
Nov 25, 2024
Browse files
Update XGrammar to the latest API (#2176)
Co-authored-by:
Ben Gitter
<
gitterbd@gmail.com
>
parent
3c5538f7
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
61 additions
and
45 deletions
+61
-45
python/pyproject.toml
python/pyproject.toml
+1
-1
python/sglang/srt/constrained/xgrammar_backend.py
python/sglang/srt/constrained/xgrammar_backend.py
+25
-42
test/srt/test_json_constrained.py
test/srt/test_json_constrained.py
+35
-2
No files found.
python/pyproject.toml
View file @
7f076c2c
...
...
@@ -22,7 +22,7 @@ runtime_common = ["aiohttp", "decord", "fastapi",
"packaging"
,
"pillow"
,
"prometheus-client>=0.20.0"
,
"psutil"
,
"pydantic"
,
"python-multipart"
,
"pyzmq>=25.1.2"
,
"torchao"
,
"uvicorn"
,
"uvloop"
,
"modelscope"
,
"xgrammar"
]
"modelscope"
,
"xgrammar
==0.1.4
"
]
srt
=
["sglang[runtime_common]
", "
torch
", "
vllm>=
0.6.3
.post
1
"]
# HIP (Heterogeneous-computing Interface for Portability) for AMD
...
...
python/sglang/srt/constrained/xgrammar_backend.py
View file @
7f076c2c
...
...
@@ -17,21 +17,14 @@ import logging
from
typing
import
List
,
Tuple
import
torch
try
:
from
xgrammar
import
(
CachedGrammarCompiler
,
CompiledGrammar
,
GrammarMatcher
,
TokenizerInfo
,
)
import_error
=
None
except
ImportError
as
e
:
CachedGrammarCompiler
=
CompiledGrammar
=
GrammarMatcher
=
TokenizerInfo
=
(
ImportError
)
import_error
=
e
from
xgrammar
import
(
CompiledGrammar
,
GrammarCompiler
,
GrammarMatcher
,
TokenizerInfo
,
allocate_token_bitmask
,
apply_token_bitmask_inplace
,
)
from
sglang.srt.constrained.base_grammar_backend
import
(
BaseGrammarBackend
,
...
...
@@ -41,7 +34,7 @@ from sglang.srt.constrained.base_grammar_backend import (
logger
=
logging
.
getLogger
(
__name__
)
MAX_ROLLBACK_TOKENS
=
1
0
MAX_ROLLBACK_TOKENS
=
20
0
class
XGrammarGrammar
(
BaseGrammarObject
):
...
...
@@ -86,21 +79,22 @@ class XGrammarGrammar(BaseGrammarObject):
def
allocate_vocab_mask
(
self
,
vocab_size
:
int
,
batch_size
:
int
,
device
)
->
torch
.
Tensor
:
return
self
.
matcher
.
allocate_token_bitmask
(
vocab
_size
,
batch
_size
)
return
allocate_token_bitmask
(
batch
_size
,
vocab
_size
)
def
fill_vocab_mask
(
self
,
vocab_mask
:
torch
.
Tensor
,
idx
:
int
)
->
None
:
self
.
matcher
.
fill_next_token_bitmask
(
vocab_mask
,
idx
)
@
staticmethod
def
apply_vocab_mask
(
logits
:
torch
.
Tensor
,
vocab_mask
:
torch
.
Tensor
)
->
None
:
GrammarMatcher
.
apply_token_bitmask_inplace
(
logits
,
vocab_mask
)
if
vocab_mask
.
device
.
type
!=
logits
.
device
.
type
:
# vocab_mask must then be on the same device as logits
# when applying the token bitmask, so we check and move if needed
vocab_mask
=
vocab_mask
.
to
(
logits
.
device
)
apply_token_bitmask_inplace
(
logits
,
vocab_mask
)
def
copy
(
self
):
matcher
=
GrammarMatcher
(
self
.
ctx
,
max_rollback_tokens
=
MAX_ROLLBACK_TOKENS
,
vocab_size
=
self
.
vocab_size
,
)
matcher
=
GrammarMatcher
(
self
.
ctx
,
max_rollback_tokens
=
MAX_ROLLBACK_TOKENS
)
return
XGrammarGrammar
(
matcher
,
self
.
vocab_size
,
self
.
ctx
)
...
...
@@ -112,25 +106,18 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
):
super
().
__init__
()
if
import_error
:
logger
.
warning
(
f
"Ignore import error for the grammar backend:
{
import_error
}
"
)
self
.
grammar_cache
=
None
return
tokenizer_info
=
TokenizerInfo
.
from_huggingface
(
tokenizer
)
self
.
grammar_cache
=
CachedGrammarCompiler
(
tokenizer_info
=
tokenizer_info
)
tokenizer_info
=
TokenizerInfo
.
from_huggingface
(
tokenizer
,
vocab_size
=
vocab_size
)
self
.
grammar_compiler
=
GrammarCompiler
(
tokenizer_info
=
tokenizer_info
)
self
.
vocab_size
=
vocab_size
def
init_value_impl
(
self
,
key
:
Tuple
[
str
,
str
])
->
XGrammarGrammar
:
if
import_error
:
raise
import_error
key_type
,
key_string
=
key
if
key_type
==
"json"
:
try
:
ctx
=
self
.
grammar_c
ache
.
compile_json_schema
_grammar
(
schema
=
key_string
)
ctx
=
self
.
grammar_c
ompiler
.
compile_json_schema
(
schema
=
key_string
)
except
RuntimeError
as
e
:
logging
.
warning
(
f
"Skip invalid json_schema: json_schema=
{
key_string
}
,
{
e
=
}
"
...
...
@@ -144,13 +131,9 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
else
:
raise
ValueError
(
f
"Invalid key_type:
{
key_type
}
"
)
matcher
=
GrammarMatcher
(
ctx
,
max_rollback_tokens
=
MAX_ROLLBACK_TOKENS
,
vocab_size
=
self
.
vocab_size
,
)
matcher
=
GrammarMatcher
(
ctx
,
max_rollback_tokens
=
MAX_ROLLBACK_TOKENS
)
return
XGrammarGrammar
(
matcher
,
self
.
vocab_size
,
ctx
)
def
reset
(
self
):
if
self
.
grammar_c
ache
:
self
.
grammar_c
ache
.
clear
()
if
self
.
grammar_c
ompiler
:
self
.
grammar_c
ompiler
.
clear_cache
()
test/srt/test_json_constrained.py
View file @
7f076c2c
...
...
@@ -17,7 +17,7 @@ from sglang.test.test_utils import (
)
class
TestJSONConstrained
(
unittest
.
TestCase
):
class
TestJSONConstrained
OutlinesBackend
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
...
...
@@ -36,7 +36,12 @@ class TestJSONConstrained(unittest.TestCase):
cls
.
model
,
cls
.
base_url
,
timeout
=
300
,
other_args
=
[
"--max-running-requests"
,
"10"
],
other_args
=
[
"--max-running-requests"
,
"10"
,
"--grammar-backend"
,
"outlines"
,
],
)
@
classmethod
...
...
@@ -121,5 +126,33 @@ class TestJSONConstrained(unittest.TestCase):
list
(
executor
.
map
(
self
.
run_decode
,
json_schemas
))
class
TestJSONConstrainedXGrammarBackend
(
TestJSONConstrainedOutlinesBackend
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
json_schema
=
json
.
dumps
(
{
"type"
:
"object"
,
"properties"
:
{
"name"
:
{
"type"
:
"string"
},
"population"
:
{
"type"
:
"integer"
},
},
"required"
:
[
"name"
,
"population"
],
}
)
cls
.
process
=
popen_launch_server
(
cls
.
model
,
cls
.
base_url
,
timeout
=
300
,
other_args
=
[
"--max-running-requests"
,
"10"
,
"--grammar-backend"
,
"xgrammar"
,
],
)
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment