sglang / Commits / ba069a24

Fix grammar backend (#2018)

Unverified commit ba069a24, authored Nov 12, 2024 by Lianmin Zheng, committed via GitHub on Nov 12, 2024.
Parent: 125b1199

Showing 11 changed files with 401 additions and 263 deletions (+401, -263).
python/sglang/srt/constrained/__init__.py               +2    -27
python/sglang/srt/constrained/base_tool_cache.py        +1    -3
python/sglang/srt/constrained/grammar.py                +0    -182
python/sglang/srt/constrained/outlines_backend.py       +203  -0
python/sglang/srt/constrained/outlines_jump_forward.py  +1    -1
python/sglang/srt/constrained/xgrammar_backend.py       +127  -0
python/sglang/srt/managers/schedule_batch.py            +7    -7
python/sglang/srt/managers/scheduler.py                 +38   -28
python/sglang/srt/sampling/sampling_batch_info.py       +2    -3
python/sglang/srt/server_args.py                        +4    -6
test/srt/test_json_constrained.py                       +16   -6
python/sglang/srt/constrained/__init__.py  [view file @ ba069a24]

@@ -13,30 +13,5 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import json
-from typing import Dict, Optional, Union
-
-from pydantic import BaseModel
-
-try:
-    from outlines.fsm.json_schema import build_regex_from_object
-except ImportError:
-    # Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
-    # which only accepts string schema as input.
-    from outlines.fsm.json_schema import build_regex_from_schema
-
-    def build_regex_from_object(
-        object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
-    ):
-        if isinstance(object, type(BaseModel)):
-            schema = json.dumps(object.model_json_schema())
-        elif isinstance(object, Dict):
-            schema = json.dumps(object)
-        else:
-            schema = object
-        return build_regex_from_schema(schema, whitespace_pattern)
-
-__all__ = [
-    "build_regex_from_object",
-]
+# TODO(lmzheng): make this an optional dependency
+from sglang.srt.constrained.outlines_backend import build_regex_from_object
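The shim keeps a stable entry point while outlines replaced build_regex_from_object with the string-only build_regex_from_schema. A hedged usage sketch (the City model is illustrative, not from the commit; requires outlines and pydantic):

    from pydantic import BaseModel
    from sglang.srt.constrained import build_regex_from_object

    class City(BaseModel):
        name: str
        population: int

    # A regex whose matches are exactly the JSON strings conforming to the schema.
    regex = build_regex_from_object(City)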
python/sglang/srt/constrained/base_tool_cache.py  [view file @ ba069a24]

@@ -95,9 +95,7 @@ class BaseToolCache:
     def get_cache_hit_rate(self):
         with self.lock_metrics:
-            if self.metrics["total"] == 0:
-                return 0
-            return self.metrics["hit"] / self.metrics["total"]
+            return self.metrics["hit"] / max(self.metrics["total"], 1)
 
     def get_avg_init_time(self):
         with self.lock_metrics:
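The max(total, 1) guard folds the zero-division branch into one expression; when no queries have been recorded, hit is also 0, so the rate is 0.0 either way:

    metrics = {"hit": 0, "total": 0}
    assert metrics["hit"] / max(metrics["total"], 1) == 0.0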
python/sglang/srt/constrained/grammar.py  deleted 100644 → 0  [view file @ 125b1199]

The entire file is deleted; its logic moves into the new outlines_backend.py and xgrammar_backend.py. The removed file:

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Cache for the compressed finite state machine."""

import logging
from concurrent.futures import Future, ThreadPoolExecutor
from typing import List, Tuple, Union

import torch

from sglang.srt.constrained.outlines_cache import OutlinesCache, RegexGuide
from sglang.srt.constrained.outlines_jump_forward import (
    OutlinesJumpCache,
    OutlinesJumpForwardMap,
)
from sglang.srt.constrained.xgrammar_cache import (
    GrammarMatcher,
    XGrammarBackend,
    XGrammarJumpCache,
)

logger = logging.getLogger(__name__)


class JumpHelper:
    def __init__(
        self, data: Union[List, str] = "", state: int = -1, suffix_ids=[]
    ) -> None:
        self.data: Union[List, str] = data
        self.state: int = state
        self.suffix_ids: List[int] = suffix_ids

    def can_jump(self):
        return len(self.data) > 0


class Grammar:
    def __init__(
        self,
        grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]],
        jump_map: Union[XGrammarJumpCache, OutlinesJumpForwardMap, None],
    ) -> None:
        self.grammar = grammar
        self.jump_map = jump_map

    def accept_token(self, token: int):
        if isinstance(self.grammar, GrammarMatcher):
            assert self.grammar.accept_token(token)
        else:
            guide, state = self.grammar
            self.grammar = guide, guide.get_next_state(state, token)

    def try_jump(self, tokenizer) -> JumpHelper:
        if isinstance(self.jump_map, XGrammarJumpCache):
            assert isinstance(self.grammar, GrammarMatcher)
            return JumpHelper(self.grammar.find_jump_forward_string())
        elif isinstance(self.jump_map, OutlinesJumpForwardMap):
            assert isinstance(self.grammar, Tuple)

            _, state = self.grammar
            jump_forward_bytes = self.jump_map.jump_forward_byte(state)
            if jump_forward_bytes is None or len(jump_forward_bytes) == 0:
                return JumpHelper()  # can't jump

            # preprocess the jump forward string
            suffix_bytes = []
            continuation_range = range(0x80, 0xC0)
            cur_state = state
            while (
                len(jump_forward_bytes)
                and jump_forward_bytes[0][0] in continuation_range
            ):
                # continuation bytes
                byte_edge = jump_forward_bytes.pop(0)
                suffix_bytes.append(byte_edge[0])
                cur_state = byte_edge[1]

            suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
            suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
            return JumpHelper(suffix_ids, cur_state, suffix_bytes)
        else:
            return JumpHelper()  # can't jump

    def jump_forward_str_state(self, helper: JumpHelper) -> Tuple[str, int]:
        if isinstance(helper.data, str):
            return helper.data, -1
        else:
            assert isinstance(self.jump_map, OutlinesJumpForwardMap)
            return self.jump_map.jump_forward_symbol(helper.state)

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        if isinstance(self.grammar, GrammarMatcher):
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == new_output_ids[i]:
                    k = i + 1
                else:
                    break

            # rollback to the last token that is the same
            if k < len(old_output_ids):
                self.grammar.rollback(len(old_output_ids) - k)

            for i in range(k, len(new_output_ids)):
                assert self.grammar.accept_token(new_output_ids[i])
        else:
            self.grammar = self.grammar[0], next_state

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
        if isinstance(self.grammar, GrammarMatcher):
            # Note that this bitmask is a bitset, not bool
            bitmask = self.grammar.get_next_token_bitmask()
            # Mask the tokens that are not allowed
            vocab_mask[
                self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size)
            ] = 1
        else:
            guide, state = self.grammar
            vocab_mask.fill_(1)
            vocab_mask[guide.get_next_instruction(state).tokens] = 0


class GrammarBackend:
    def __init__(
        self,
        tokenizer_path,
        tokenizer_args_dict,
        skip_tokenizer_init=False,
        whitespace_patterns=None,
        backend=None,
        allow_jump=False,
    ):
        self.executor = ThreadPoolExecutor()
        self.backend = backend

        if backend == "xgrammar":
            self.grammar_cache = XGrammarBackend(
                tokenizer_path=tokenizer_path,
                tokenizer_args_dict=tokenizer_args_dict,
                skip_tokenizer_init=skip_tokenizer_init,
                whitespace_patterns=whitespace_patterns,
            )
            self.jump_cache = XGrammarJumpCache() if allow_jump else None
        else:
            assert backend == "outlines"
            self.grammar_cache = OutlinesCache(
                tokenizer_path=tokenizer_path,
                tokenizer_args_dict=tokenizer_args_dict,
                skip_tokenizer_init=skip_tokenizer_init,
                constrained_json_whitespace_pattern=whitespace_patterns,
            )
            self.jump_cache = OutlinesJumpCache() if allow_jump else None

    def _query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
        if isinstance(self.grammar_cache, XGrammarBackend):
            return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache)
        else:
            guide, regex = self.grammar_cache.query(key)
            jump_map = self.jump_cache.query(regex)
            return Grammar((guide, 0), jump_map)

    def query(self, key: Tuple[str, str], vocab_size: int) -> Future:
        return self.executor.submit(self._query, key, vocab_size)

    def reset(self):
        self.grammar_cache.reset()
        self.jump_cache.reset()
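One detail worth pulling out of try_jump above: a jump-forward string can end in the middle of a multi-byte UTF-8 character, so leading continuation bytes (0x80-0xBF) are peeled off and emitted as explicit byte tokens before retokenizing. A self-contained sketch of that formatting (the <0xNN> spelling is the byte-token convention of Llama-style vocabularies, an assumption on my part):

    suffix_bytes = [0x9C, 0x88]  # two UTF-8 continuation bytes
    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
    assert suffix_tokens == ["<0x9C>", "<0x88>"]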
python/sglang/srt/constrained/outlines_cache.py → python/sglang/srt/constrained/outlines_backend.py  [view file @ ba069a24]

@@ -13,41 +13,139 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""Cache for the compressed finite state machine."""
+"""Constrained decoding with outlines backend."""
 
+import json
 import logging
+from concurrent.futures import Future, ThreadPoolExecutor
+from typing import Dict, List, Optional, Tuple, Union
 
+import torch
 from interegular import InvalidSyntax, parse_pattern
 from outlines.fsm.guide import RegexGuide
 from outlines.models.transformers import TransformerTokenizer
-from transformers import AutoTokenizer
+from pydantic import BaseModel
 
-from sglang.srt.constrained import build_regex_from_object
 from sglang.srt.constrained.base_tool_cache import BaseToolCache
+from sglang.srt.constrained.outlines_jump_forward import (
+    OutlinesJumpForwardCache,
+    OutlinesJumpForwardMap,
+)
 
 logger = logging.getLogger(__name__)
 
-class OutlinesCache(BaseToolCache):
-    def __init__(
-        self,
-        tokenizer_path,
-        tokenizer_args_dict,
-        enable=True,
-        skip_tokenizer_init=False,
-        constrained_json_whitespace_pattern=None,
-    ):
-        super().__init__(enable=enable)
-
-        if (
-            skip_tokenizer_init
-            or tokenizer_path.endswith(".json")
-            or tokenizer_path.endswith(".model")
-        ):
-            # Do not support TiktokenTokenizer or SentencePieceTokenizer
-            return
-
-        tokenizer_args_dict.setdefault("padding_side", "left")
-        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
+try:
+    from outlines.fsm.json_schema import build_regex_from_object
+except ImportError:
+    # Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
+    # which only accepts string schema as input.
+    from outlines.fsm.json_schema import build_regex_from_schema
+
+    def build_regex_from_object(
+        object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
+    ):
+        if isinstance(object, type(BaseModel)):
+            schema = json.dumps(object.model_json_schema())
+        elif isinstance(object, Dict):
+            schema = json.dumps(object)
+        else:
+            schema = object
+        return build_regex_from_schema(schema, whitespace_pattern)
+
+
+class OutlinesGrammar:
+    def __init__(
+        self,
+        guide: RegexGuide,
+        state: int,
+        jump_forward_map: Union[OutlinesJumpForwardMap, None],
+    ) -> None:
+        self.guide = guide
+        self.state = state
+        self.jump_forward_map = jump_forward_map
+
+    def accept_token(self, token: int):
+        self.state = self.guide.get_next_state(self.state, token)
+
+    def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
+        if not self.jump_forward_map:
+            return None
+
+        jump_forward_bytes = self.jump_forward_map.jump_forward_byte(self.state)
+        if jump_forward_bytes is None or len(jump_forward_bytes) <= 1:
+            return None
+
+        # preprocess the jump forward string
+        suffix_bytes = []
+        continuation_range = range(0x80, 0xC0)
+        cur_state = self.state
+        while (
+            len(jump_forward_bytes)
+            and jump_forward_bytes[0][0] in continuation_range
+        ):
+            # continuation bytes
+            byte_edge = jump_forward_bytes.pop(0)
+            suffix_bytes.append(byte_edge[0])
+            cur_state = byte_edge[1]
+
+        suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
+        suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
+        return suffix_ids, cur_state
+
+    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
+        _, cur_state = helper
+        return self.jump_forward_map.jump_forward_symbol(cur_state)
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ):
+        self.state = next_state
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor):
+        vocab_mask.fill_(1)
+        vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
+
+
+class OutlinesGrammarBackend:
+    def __init__(
+        self,
+        tokenizer,
+        whitespace_patterns: bool,
+        allow_jump_forward: bool,
+    ):
+        self.executor = ThreadPoolExecutor()
+        self.grammar_cache = OutlinesCache(
+            tokenizer,
+            whitespace_pattern=whitespace_patterns,
+        )
+        self.jump_forward_cache = (
+            OutlinesJumpForwardCache() if allow_jump_forward else None
+        )
+
+    def _query(self, key: Tuple[str, str]) -> OutlinesGrammar:
+        guide, regex = self.grammar_cache.query(key)
+        jump_forward_map = (
+            self.jump_forward_cache.query(regex) if self.jump_forward_cache else None
+        )
+        return OutlinesGrammar(guide, 0, jump_forward_map)
+
+    def query(self, key: Tuple[str, str]) -> Future:
+        return self.executor.submit(self._query, key)
+
+    def reset(self):
+        self.grammar_cache.reset()
+        if self.jump_forward_cache:
+            self.jump_forward_cache.reset()
+
+
+class OutlinesCache(BaseToolCache):
+    def __init__(
+        self,
+        tokenizer,
+        whitespace_pattern=None,
+    ):
+        super().__init__(enable=True)
+
         try:
             self.outlines_tokenizer = TransformerTokenizer(tokenizer)
         except AttributeError:

@@ -69,7 +167,7 @@ class OutlinesCache(BaseToolCache):
             self.outlines_tokenizer.vocabulary = (
                 self.outlines_tokenizer.tokenizer.get_vocab()
             )
-        self.constrained_json_whitespace_pattern = constrained_json_whitespace_pattern
+        self.whitespace_pattern = whitespace_pattern
 
     def init_value(self, key):
         key_type, key_string = key

@@ -77,7 +175,7 @@ class OutlinesCache(BaseToolCache):
         try:
             regex = build_regex_from_object(
                 key_string,
-                whitespace_pattern=self.constrained_json_whitespace_pattern,
+                whitespace_pattern=self.whitespace_pattern,
             )
         except NotImplementedError as e:
             logger.warning(

@@ -93,4 +191,13 @@ class OutlinesCache(BaseToolCache):
         except InvalidSyntax as e:
             logger.warning(f"skip invalid regex guide: {regex=}, {e=}")
             return None, regex
 
-        return RegexGuide(regex, self.outlines_tokenizer), regex
+        ret = RegexGuide(regex, self.outlines_tokenizer), regex
+        return ret
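A hedged sketch of how a scheduler would consume this backend (placeholder values throughout; query() returns a Future because grammar compilation runs on the executor thread):

    # tokenizer: a HuggingFace tokenizer instance (placeholder)
    backend = OutlinesGrammarBackend(
        tokenizer, whitespace_patterns=None, allow_jump_forward=True
    )
    future = backend.query(("regex", r"(yes|no)"))  # or ("json", schema_str)
    grammar = future.result()                       # an OutlinesGrammar
    # vocab_mask: torch.Tensor of vocab size (placeholder); 1 means disallowed
    grammar.fill_vocab_mask(vocab_mask)
    grammar.accept_token(sampled_token_id)          # advance the guide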
python/sglang/srt/constrained/outlines_jump_forward.py  [view file @ ba069a24]

@@ -164,7 +164,7 @@ class OutlinesJumpForwardMap:
         )
 
 
-class OutlinesJumpCache(BaseToolCache):
+class OutlinesJumpForwardCache(BaseToolCache):
     def __init__(self):
         super().__init__()
python/sglang/srt/constrained/xgrammar_cache.py → python/sglang/srt/constrained/xgrammar_backend.py  [view file @ ba069a24]

@@ -3,7 +3,9 @@ Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
+
     http://www.apache.org/licenses/LICENSE-2.0
+
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

@@ -11,50 +13,98 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""Cache for the compressed finite state machine."""
+"""Constrained decoding with xgrammar backend."""
 
-from typing import Tuple
+from concurrent.futures import Future, ThreadPoolExecutor
+from typing import List, Tuple
 
-from transformers import AutoTokenizer
+import torch
 
 try:
     from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
-except ImportError:
+
+    import_error = None
+except ImportError as e:
+    import_error = e
 
     class Dummy:
         pass
 
-    GrammarMatcher = Dummy
-    CompiledGrammar = Dummy
-    CachedGrammarCompiler = Dummy
+    GrammarMatcher = CompiledGrammar = CachedGrammarCompiler = Dummy
 
 MAX_ROLLBACK_TOKENS = 10
 
-class XGrammarJumpCache:
-    """A dummy class."""
-
-    def reset(self):
-        pass
-
-
-class XGrammarBackend:
-    def __init__(
-        self,
-        tokenizer_path,
-        tokenizer_args_dict,
-        skip_tokenizer_init=False,
-        whitespace_patterns=None,
-    ):
-        # TODO(dark): how to deal with whitespace_patterns and skip_tokenizer_init
-        if skip_tokenizer_init:
-            return
-
-        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
-        self.grammar_cache: CachedGrammarCompiler = CachedGrammarCompiler(
-            tokenizer_or_vocab=tokenizer
-        )
+class XGrammarGrammar:
+    def __init__(self, matcher: GrammarMatcher, vocab_size: int) -> None:
+        self.matcher = matcher
+        self.vocab_size = vocab_size
+
+    def accept_token(self, token: int):
+        assert self.matcher.accept_token(token)
+
+    def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]:
+        return [], self.matcher.find_jump_forward_string()
+
+    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
+        _, data = helper
+        return data, -1
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ):
+        k = 0
+        for i, old_id in enumerate(old_output_ids):
+            if old_id == new_output_ids[i]:
+                k = i + 1
+            else:
+                break
+
+        # rollback to the last token that is the same
+        if k < len(old_output_ids):
+            self.matcher.rollback(len(old_output_ids) - k)
+
+        for i in range(k, len(new_output_ids)):
+            assert self.matcher.accept_token(new_output_ids[i])
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor):
+        # Note that this bitmask is a bitset, not bool
+        bitmask = self.matcher.get_next_token_bitmask()
+        # Mask the tokens that are not allowed
+        vocab_mask[
+            self.matcher.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size)
+        ] = 1
+
+
+class XGrammarGrammarBackend:
+    def __init__(
+        self,
+        tokenizer,
+        vocab_size: int,
+    ):
+        if import_error:
+            raise import_error
+
+        self.executor = ThreadPoolExecutor()
+        self.grammar_cache = XGrammarCache(tokenizer, vocab_size)
+        self.vocab_size = vocab_size
+
+    def _query(self, key: Tuple[str, str]) -> XGrammarGrammar:
+        return XGrammarGrammar(self.grammar_cache.query(key), self.vocab_size)
+
+    def query(self, key: Tuple[str, str]) -> Future:
+        return self.executor.submit(self._query, key)
+
+    def reset(self):
+        self.grammar_cache.reset()
+
+
+class XGrammarCache:
+    def __init__(self, tokenizer, vocab_size: int):
+        self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
+        self.vocab_size = vocab_size
 
     def get_context(self, key: Tuple[str, str]) -> CompiledGrammar:
         key_type, key_string = key

@@ -65,10 +115,12 @@ class XGrammarBackend:
         else:
            raise ValueError(f"Invalid key_type: {key_type}")
 
-    def query(self, key: Tuple[str, str], vocab_size: int) -> GrammarMatcher:
+    def query(self, key: Tuple[str, str]) -> GrammarMatcher:
         ctx = self.get_context(key)
         return GrammarMatcher(
-            ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size
+            ctx,
+            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
+            mask_vocab_size=self.vocab_size,
         )
 
     def reset(self):
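The import_error capture is the pattern that makes xgrammar optional: the exception is stored at module import time and re-raised only when the backend is actually constructed. A self-contained sketch (nonexistent_module is a deliberate stand-in for the optional dependency):

    try:
        import nonexistent_module  # stand-in for the optional dependency
        import_error = None
    except ImportError as e:
        import_error = e

    class Backend:
        def __init__(self):
            if import_error:
                # Surface the original ImportError lazily, at first use.
                raise import_error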
python/sglang/srt/managers/schedule_batch.py  [view file @ ba069a24]

@@ -37,7 +37,6 @@ import torch
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained.grammar import Grammar
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool

@@ -249,7 +248,7 @@ class Req:
         self.embedding = None
 
         # Constrained decoding
-        self.grammar: Optional[Grammar] = None
+        self.grammar = None
 
         # The number of cached tokens, that were already cached in the KV cache
         self.cached_tokens = 0

@@ -359,8 +358,6 @@ class Req:
             return
 
     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        assert self.grammar is not None and self.tokenizer is not None
-
         if self.origin_input_text is None:
             # Recovering text can only use unpadded ids
             self.origin_input_text = self.tokenizer.decode(

@@ -809,9 +806,10 @@ class ScheduleBatch:
         for i, req in enumerate(self.reqs):
             if req.grammar is not None:
-                jump_helper = req.grammar.try_jump(req.tokenizer)
-                if jump_helper.can_jump():
-                    suffix_ids = jump_helper.suffix_ids
+                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
+                if jump_helper:
+                    suffix_ids, _ = jump_helper
 
                     # Current ids, for cache and revert
                     cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                     cur_output_ids = req.output_ids

@@ -827,6 +825,8 @@ class ScheduleBatch:
                         next_state,
                     ) = req.grammar.jump_forward_str_state(jump_helper)
 
+                    # Make the incrementally decoded text part of jump_forward_str
+                    # so that the UTF-8 will not corrupt
                     jump_forward_str = new_text + jump_forward_str
                     if not req.jump_forward_and_retokenize(
                         jump_forward_str, next_state
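The added comment is load-bearing: the most recent tokens may end partway through a multi-byte UTF-8 character, so the incrementally decoded text and the jump-forward string must be joined as text and retokenized together. A self-contained illustration of the corruption this avoids:

    emoji = "🙂".encode("utf-8")  # four bytes: f0 9f 99 82
    print(emoji[:2].decode("utf-8", errors="replace"))  # replacement char (split mid-character)
    print(emoji.decode("utf-8"))                        # '🙂' (decoded as a whole)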
python/sglang/srt/managers/scheduler.py  [view file @ ba069a24]

@@ -21,6 +21,7 @@ import threading
 import time
 import warnings
 from collections import deque
+from concurrent import futures
 from types import SimpleNamespace
 from typing import List, Optional

@@ -29,7 +30,6 @@ import zmq
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained.grammar import GrammarBackend
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (

@@ -100,7 +100,7 @@ class Scheduler:
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.schedule_policy = server_args.schedule_policy
-        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
+        self.disable_jump_forward = server_args.disable_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = server_args.enable_overlap_schedule

@@ -234,22 +234,33 @@ class Scheduler:
             self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
         )
 
-        # Init the grammar cache for constrained generation
-        self.grammar_cache = None
+        # Init the grammar backend for constrained generation
+        self.grammar_queue: List[Req] = []
         if not server_args.skip_tokenizer_init:
-            self.grammar_cache = GrammarBackend(
-                server_args.tokenizer_path,
-                {
-                    "tokenizer_mode": server_args.tokenizer_mode,
-                    "trust_remote_code": server_args.trust_remote_code,
-                },
-                skip_tokenizer_init=server_args.skip_tokenizer_init,
-                whitespace_patterns=server_args.constrained_json_whitespace_pattern,
-                backend=server_args.grammar_backend,
-                allow_jump=not server_args.disable_regex_jump_forward,
-            )
+            if server_args.grammar_backend == "outlines":
+                from sglang.srt.constrained.outlines_backend import (
+                    OutlinesGrammarBackend,
+                )
+
+                self.grammar_backend = OutlinesGrammarBackend(
+                    self.tokenizer,
+                    whitespace_patterns=server_args.constrained_json_whitespace_pattern,
+                    allow_jump_forward=not server_args.disable_jump_forward,
+                )
+            elif server_args.grammar_backend == "xgrammar":
+                from sglang.srt.constrained.xgrammar_backend import (
+                    XGrammarGrammarBackend,
+                )
+
+                self.grammar_backend = XGrammarGrammarBackend(
+                    self.tokenizer, vocab_size=self.model_config.vocab_size
+                )
+            else:
+                raise ValueError(
+                    f"Invalid grammar backend: {server_args.grammar_backend}"
+                )
+        else:
+            self.grammar_backend = None
 
         # Init new token estimation
         assert (

@@ -461,15 +472,14 @@ class Scheduler:
             req.sampling_params.json_schema is not None
             or req.sampling_params.regex is not None
         ):
-            assert self.grammar_cache is not None
+            assert self.grammar_backend is not None
             if req.sampling_params.json_schema is not None:
-                req.grammar = self.grammar_cache.query(
-                    ("json", req.sampling_params.json_schema),
-                    self.model_config.vocab_size,
-                )
+                req.grammar = self.grammar_backend.query(
+                    ("json", req.sampling_params.json_schema)
+                )
             elif req.sampling_params.regex is not None:
-                req.grammar = self.grammar_cache.query(
-                    ("regex", req.sampling_params.regex), self.model_config.vocab_size
-                )
+                req.grammar = self.grammar_backend.query(
+                    ("regex", req.sampling_params.regex)
+                )
 
         # Truncate prompts that are too long

@@ -638,14 +648,14 @@ class Scheduler:
             return self.running_batch
 
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
-        # Check if the grammar queue is ready
+        # Check if the grammar is ready in the grammar queue
         if self.grammar_queue:
             new_grammar_queue = []
             for req in self.grammar_queue:
-                if req.grammar.done():
-                    req.grammar = req.grammar.result()
+                try:
+                    req.grammar = req.grammar.result(timeout=0.05)
                     self.waiting_queue.append(req)
-                else:
+                except futures._base.TimeoutError:
                     new_grammar_queue.append(req)
             self.grammar_queue = new_grammar_queue

@@ -783,7 +793,7 @@ class Scheduler:
         )
 
         # Check for jump-forward
-        if not self.disable_regex_jump_forward:
+        if not self.disable_jump_forward:
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():

@@ -1142,8 +1152,8 @@ class Scheduler:
         ):
             self.tree_cache.reset()
             self.tree_cache_metrics = {"total": 0, "hit": 0}
-            if self.grammar_cache is not None:
-                self.grammar_cache.reset()
+            if self.grammar_backend is not None:
+                self.grammar_backend.reset()
             # TODO(dark): reset the bnf cache
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
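The scheduler change above replaces done()-polling with a short blocking wait per queued grammar: each Future gets at most 50 ms before the request is re-queued. A self-contained sketch of the same pattern (futures.TimeoutError is the public alias of futures._base.TimeoutError):

    import time
    from concurrent import futures
    from concurrent.futures import ThreadPoolExecutor

    def slow_compile(key):
        time.sleep(0.2)  # stand-in for grammar compilation
        return f"grammar({key})"

    executor = ThreadPoolExecutor()
    queue = [executor.submit(slow_compile, k) for k in ("json", "regex")]
    ready = []
    while queue:
        pending = []
        for fut in queue:
            try:
                # Wait at most 50 ms, then move on with the rest of the loop.
                ready.append(fut.result(timeout=0.05))
            except futures.TimeoutError:
                pending.append(fut)
        queue = pending
    print(ready)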
python/sglang/srt/sampling/sampling_batch_info.py  [view file @ ba069a24]

@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, List, Optional
 import torch
 
 import sglang.srt.sampling.penaltylib as penaltylib
-from sglang.srt.constrained.grammar import Grammar
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch

@@ -31,7 +30,7 @@ class SamplingBatchInfo:
     logit_bias: torch.Tensor = None
     vocab_mask: Optional[torch.Tensor] = None
 
-    grammars: Optional[List[Optional[Grammar]]] = None
+    grammars: Optional[List] = None
 
     # Penalizer
     penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None

@@ -146,7 +145,7 @@ class SamplingBatchInfo:
         )
         for i, grammar in enumerate(self.grammars):
             if grammar is not None:
-                grammar.fill_vocab_mask(self.vocab_mask[i], self.vocab_size)
+                grammar.fill_vocab_mask(self.vocab_mask[i])
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         if self.penalizer_orchestrator:
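fill_vocab_mask now receives only the per-request mask row; each grammar object already carries the vocab size where it needs it. For context, a minimal sketch of how such a mask gates sampling (shapes and names are assumed, not sglang code):

    import torch

    vocab_size = 8
    logits = torch.randn(vocab_size)
    vocab_mask = torch.zeros(vocab_size, dtype=torch.bool)

    allowed = [0, 3, 5]     # token ids the grammar permits next
    vocab_mask.fill_(True)  # start from "everything disallowed"
    vocab_mask[allowed] = False
    logits = logits.masked_fill(vocab_mask, float("-inf"))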
python/sglang/srt/server_args.py  [view file @ ba069a24]

@@ -111,7 +111,7 @@ class ServerArgs:
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
-    disable_regex_jump_forward: bool = False
+    disable_jump_forward: bool = False
     disable_cuda_graph: bool = False
     disable_cuda_graph_padding: bool = False
     disable_disk_cache: bool = False

@@ -574,7 +574,7 @@ class ServerArgs:
             type=str,
             choices=["xgrammar", "outlines"],
             default=ServerArgs.grammar_backend,
-            help="Choose the backend for constrained decoding.",
+            help="Choose the backend for grammar-guided decoding.",
         )
 
         # Optimization/debug options

@@ -594,9 +594,9 @@ class ServerArgs:
             help="Disable RadixAttention for prefix caching.",
         )
         parser.add_argument(
-            "--disable-regex-jump-forward",
+            "--disable-jump-forward",
             action="store_true",
-            help="Disable regex jump-forward.",
+            help="Disable jump-forward for grammar-guided decoding.",
         )
         parser.add_argument(
             "--disable-cuda-graph",

@@ -616,7 +616,6 @@ class ServerArgs:
         parser.add_argument(
             "--disable-custom-all-reduce",
             action="store_true",
-            default=False,
             help="Disable the custom all-reduce kernel and fall back to NCCL.",
         )
         parser.add_argument(

@@ -688,7 +687,6 @@ class ServerArgs:
         )
         parser.add_argument(
             "--delete-ckpt-after-loading",
-            default=ServerArgs.delete_ckpt_after_loading,
             action="store_true",
             help="Delete the model checkpoint after loading the model.",
         )
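The two dropped default= lines are redundant with argparse semantics: action="store_true" already implies default=False. A quick check:

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--disable-jump-forward", action="store_true")
    assert p.parse_args([]).disable_jump_forward is False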
test/srt/test_json_constrained.py  [view file @ ba069a24]

@@ -61,18 +61,27 @@ class TestJSONConstrained(unittest.TestCase):
                 "logprob_start_len": 0,
             },
         )
-        print(json.dumps(response.json()))
+        ret = response.json()
+        print(json.dumps(ret))
         print("=" * 100)
 
+        if not json_schema:
+            return
+
         # Make sure the json output is valid
         try:
-            js_obj = json.loads(response.json()["text"])
+            js_obj = json.loads(ret["text"])
         except (TypeError, json.decoder.JSONDecodeError):
             raise
-        assert isinstance(js_obj["name"], str)
-        assert isinstance(js_obj["population"], int)
+        self.assertIsInstance(js_obj["name"], str)
+        self.assertIsInstance(js_obj["population"], int)
+
+        # Make sure jump forward is triggered
+        self.assertGreater(
+            ret["meta_info"]["completion_tokens"],
+            ret["meta_info"]["completion_tokens_wo_jump_forward"],
+        )
 
     def test_json_generate(self):
         self.run_decode(json_schema=self.json_schema)

@@ -100,8 +109,9 @@ class TestJSONConstrained(unittest.TestCase):
         except (TypeError, json.decoder.JSONDecodeError):
+            print("JSONDecodeError", text)
             raise
-        assert isinstance(js_obj["name"], str), f"{js_obj=}"
-        assert isinstance(js_obj["population"], int), f"{js_obj=}"
+        self.assertIsInstance(js_obj["name"], str)
+        self.assertIsInstance(js_obj["population"], int)
 
     def test_mix_json_and_other(self):
         json_schemas = [None, None, self.json_schema, self.json_schema] * 10