sglang / Commits / 125b1199

Unverified commit 125b1199, authored Nov 13, 2024 by DarkSharpness, committed by GitHub Nov 12, 2024
Parent: eff468dd

support parallel grammar preprocessing (#1996)

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>

Showing 9 changed files with 159 additions and 141 deletions
python/sglang/srt/constrained/__init__.py                 +0   -39
python/sglang/srt/constrained/base_tool_cache.py          +62  -23
python/sglang/srt/constrained/grammar.py                  +36  -44
python/sglang/srt/constrained/outlines_cache.py           +5   -4
python/sglang/srt/constrained/outlines_jump_forward.py    +7   -12
python/sglang/srt/constrained/xgrammar_cache.py           +25  -11
python/sglang/srt/managers/scheduler.py                   +20  -5
python/sglang/srt/model_executor/model_runner.py          +2   -1
test/srt/test_json_constrained.py                         +2   -2
python/sglang/srt/constrained/__init__.py

@@ -13,25 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""For constrained decoding."""
 import json
 from typing import Dict, Optional, Union
 
 from pydantic import BaseModel
 
-try:
-    from outlines.caching import cache as disk_cache
-    from outlines.caching import disable_cache
-    from outlines.fsm.guide import RegexGuide
-    from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
-    from outlines.models.transformers import TransformerTokenizer
-except ImportError as e:
-    print(
-        f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n'
-    )
-    raise
-
 try:
     from outlines.fsm.json_schema import build_regex_from_object
 except ImportError:
@@ -51,31 +37,6 @@ except ImportError:
         return build_regex_from_schema(schema, whitespace_pattern)
 
 
-try:
-    from xgrammar import (
-        GrammarMatcher,
-        GrammarMatcherInitContext,
-        GrammarMatcherInitContextCache,
-    )
-except ImportError as e:
-
-    class Dummy:
-        pass
-
-    GrammarMatcher = Dummy
-    GrammarMatcherInitContext = Dummy
-    GrammarMatcherInitContextCache = Dummy
-
 __all__ = [
-    "RegexGuide",
-    "FSMInfo",
-    "make_deterministic_fsm",
     "build_regex_from_object",
-    "TransformerTokenizer",
-    "disk_cache",
-    "disable_cache",
-    "make_byte_level_fsm",
-    "GrammarMatcher",
-    "GrammarMatcherInitContext",
-    "GrammarMatcherInitContextCache",
 ]
python/sglang/srt/constrained/base_tool_cache.py

@@ -13,25 +13,47 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""Base tool cache for constrained decoding tools."""
+"""Base cache class for constrained decoding tools."""
 
 import time
+from dataclasses import dataclass
+from threading import Event, Lock
+from typing import Any, Dict, Tuple
+
+
+@dataclass
+class MapEntry:
+    event: Event
+    value: Any
+
+    def __iter__(self):
+        return iter((self.event, self.value))
 
 
 class BaseToolCache:
     def __init__(self, enable=True):
-        self.enable = enable
+        self.enable: bool = enable
+        self.cache: Dict[str, MapEntry] = {}
+        self.metrics: Dict[str, Any] = {}
+        self.lock_cache: Lock = Lock()
+        self.lock_metrics: Lock = Lock()
         self.reset()
 
     def reset(self):
-        self.cache = {}
-        self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0}
+        with self.lock_cache:
+            self.cache = {}
+        with self.lock_metrics:
+            self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0}
 
-    def query(self, key):
-        def _init_with_timer(key):
-            start = time.monotonic()
-            val = self.init_value(key)
-            init_time = time.monotonic() - start
-            curr_total = self.metrics["total"]
-            new_total = curr_total + 1
+    def _init_with_timer(self, key) -> Tuple[Any, float]:
+        start = time.monotonic()
+        val = self.init_value(key)
+        init_time = time.monotonic() - start
+        return val, init_time
+
+    def update_time(self, init_time):
+        with self.lock_metrics:
+            curr_total = self.metrics["total"]
+            new_total = curr_total + 1
@@ -39,27 +61,44 @@ class BaseToolCache:
             self.metrics["avg_init_time"] = (init_time / new_total) + (
                 curr_total / new_total
             ) * self.metrics["avg_init_time"]
-            return val
 
-        if key in self.cache:
-            self.metrics["hit"] += 1
-            val = self.cache[key]
-        else:
-            # Cache miss or disabled.
-            val = _init_with_timer(key)
+    def query(self, key):
+        if not self.enable:
+            value, init_time = self._init_with_timer(key)
+            self.update_time(init_time)
+            return value
 
-        if self.enable:
-            self.metrics["total"] += 1
-            self.cache[key] = val
-        return val
+        with self.lock_cache:
+            if key in self.cache:
+                entry = self.cache[key]
+                cache_hit = True
+            else:
+                entry = MapEntry(Event(), None)
+                self.cache[key] = entry
+                cache_hit = False
+
+        with self.lock_metrics:
+            self.metrics["total"] += 1
+            if cache_hit:
+                self.metrics["hit"] += 1
+
+        if cache_hit:
+            entry.event.wait()
+        else:
+            entry.value, init_time = self._init_with_timer(key)
+            self.update_time(init_time)
+            entry.event.set()
+        return entry.value
 
     def init_value(self, key):
         raise NotImplementedError()
 
     def get_cache_hit_rate(self):
-        if self.metrics["total"] == 0:
-            return 0
-        return self.metrics["hit"] / self.metrics["total"]
+        with self.lock_metrics:
+            if self.metrics["total"] == 0:
+                return 0
+            return self.metrics["hit"] / self.metrics["total"]
 
     def get_avg_init_time(self):
-        return self.metrics["avg_init_time"]
+        with self.lock_metrics:
+            return self.metrics["avg_init_time"]
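The reworked BaseToolCache above is what makes parallel preprocessing safe: the first thread to ask for a key stores a MapEntry with an unset Event and does the compilation, while any other thread asking for the same key blocks on entry.event.wait() instead of compiling a second time. Below is a standalone sketch of that per-key Event pattern, for illustration only; compile_grammar and DemoCache are made-up names, not sglang APIs.

# Minimal sketch of the per-key Event pattern (illustrative only; not sglang code).
import time
from dataclasses import dataclass
from threading import Event, Lock
from typing import Any, Dict


def compile_grammar(key: str) -> str:
    time.sleep(0.1)          # stand-in for an expensive grammar compilation
    return f"compiled({key})"


@dataclass
class Entry:
    event: Event
    value: Any = None


class DemoCache:
    def __init__(self):
        self._lock = Lock()
        self._cache: Dict[str, Entry] = {}

    def query(self, key: str) -> Any:
        # Reserve the slot under the lock so exactly one thread becomes the producer.
        with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                entry = Entry(Event())
                self._cache[key] = entry
                producer = True
            else:
                producer = False

        if producer:
            entry.value = compile_grammar(key)   # heavy work happens outside the lock
            entry.event.set()                    # wake up any waiting consumers
        else:
            entry.event.wait()                   # block until the producer finishes
        return entry.value

If several worker threads call query with the same key, only the first one compiles; the rest sleep on the Event and reuse the result, which is what lets the grammar backend push compilations onto a thread pool without duplicating work.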
python/sglang/srt/constrained/grammar.py

@@ -13,50 +13,44 @@ limitations under the License.
 """Cache for the compressed finite state machine."""
 
 import logging
-from typing import List, Optional, Tuple, Union
+from concurrent.futures import Future, ThreadPoolExecutor
+from typing import List, Tuple, Union
 
 import torch
 
-from sglang.srt.constrained import GrammarMatcher, RegexGuide
-from sglang.srt.constrained.bnf_cache import BNFCache
-from sglang.srt.constrained.fsm_cache import FSMCache
-from sglang.srt.constrained.jump_forward import JumpForwardCache, JumpForwardMap
-
-# from sglang.srt.managers.schedule_batch import Req
+from sglang.srt.constrained.outlines_cache import OutlinesCache, RegexGuide
+from sglang.srt.constrained.outlines_jump_forward import (
+    OutlinesJumpCache,
+    OutlinesJumpForwardMap,
+)
+from sglang.srt.constrained.xgrammar_cache import (
+    GrammarMatcher,
+    XGrammarBackend,
+    XGrammarJumpCache,
+)
 
 logger = logging.getLogger(__name__)
 
-INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
-
-
-class XGrammarJump:
-    pass
-
 
 class JumpHelper:
-    data: Union[List, str]
-    state: int
-    suffix_ids: List[int]
-
     def __init__(
         self, data: Union[List, str] = "", state: int = -1, suffix_ids=[]
     ) -> None:
-        self.data = data
-        self.state = state
-        self.suffix_ids = suffix_ids
+        self.data: Union[List, str] = data
+        self.state: int = state
+        self.suffix_ids: List[int] = suffix_ids
 
     def can_jump(self):
        return len(self.data) > 0
 
 
 class Grammar:
-    grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]]
-    jump_map: Union[XGrammarJump, JumpForwardMap, None]
-
     def __init__(
         self,
         grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]],
-        jump_map: Union[XGrammarJump, JumpForwardMap, None],
+        jump_map: Union[XGrammarJumpCache, OutlinesJumpForwardMap, None],
     ) -> None:
         self.grammar = grammar
         self.jump_map = jump_map
@@ -69,10 +63,10 @@ class Grammar:
         self.grammar = guide, guide.get_next_state(state, token)
 
     def try_jump(self, tokenizer) -> JumpHelper:
-        if isinstance(self.jump_map, XGrammarJump):
+        if isinstance(self.jump_map, XGrammarJumpCache):
             assert isinstance(self.grammar, GrammarMatcher)
             return JumpHelper(self.grammar.find_jump_forward_string())
-        elif isinstance(self.jump_map, JumpForwardMap):
+        elif isinstance(self.jump_map, OutlinesJumpForwardMap):
             assert isinstance(self.grammar, Tuple)
             _, state = self.grammar
@@ -103,7 +97,7 @@ class Grammar:
         if isinstance(helper.data, str):
             return helper.data, -1
         else:
-            assert isinstance(self.jump_map, JumpForwardMap)
+            assert isinstance(self.jump_map, OutlinesJumpForwardMap)
             return self.jump_map.jump_forward_symbol(helper.state)
 
     def jump_and_retokenize(
@@ -129,7 +123,7 @@ class Grammar:
     def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
         if isinstance(self.grammar, GrammarMatcher):
             # Note that this bitmask is a bitset, not bool
-            bitmask = self.grammar.find_next_token_bitmask()
+            bitmask = self.grammar.get_next_token_bitmask()
             # Mask the tokens that are not allowed
             vocab_mask[
                 self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size)
@@ -140,9 +134,7 @@ class Grammar:
             vocab_mask[guide.get_next_instruction(state).tokens] = 0
 
 
-class GrammarCache:
-    grammar_cache: Union[BNFCache, FSMCache]
-    jump_cache: Union[XGrammarJump, JumpForwardCache, None]
-
+class GrammarBackend:
     def __init__(
         self,
@@ -153,38 +145,38 @@ class GrammarCache:
         backend=None,
         allow_jump=False,
     ):
+        self.executor = ThreadPoolExecutor()
+        self.backend = backend
         if backend == "xgrammar":
-            self.grammar_cache = BNFCache(
+            self.grammar_cache = XGrammarBackend(
                 tokenizer_path=tokenizer_path,
                 tokenizer_args_dict=tokenizer_args_dict,
                 skip_tokenizer_init=skip_tokenizer_init,
                 whitespace_patterns=whitespace_patterns,
             )
-            self.jump_cache = XGrammarJump() if allow_jump else None
+            self.jump_cache = XGrammarJumpCache() if allow_jump else None
         else:
             assert backend == "outlines"
-            self.grammar_cache = FSMCache(
+            self.grammar_cache = OutlinesCache(
                 tokenizer_path=tokenizer_path,
                 tokenizer_args_dict=tokenizer_args_dict,
                 skip_tokenizer_init=skip_tokenizer_init,
                 constrained_json_whitespace_pattern=whitespace_patterns,
-                enable=True,
             )
-            self.jump_cache = JumpForwardCache() if allow_jump else None
+            self.jump_cache = OutlinesJumpCache() if allow_jump else None
 
-    def query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
-        if isinstance(self.grammar_cache, BNFCache):
-            assert not isinstance(self.jump_cache, JumpForwardCache)
+    def _query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
+        if isinstance(self.grammar_cache, XGrammarBackend):
             return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache)
         else:
-            jump_map = None
             guide, regex = self.grammar_cache.query(key)
-            if isinstance(self.jump_cache, JumpForwardCache):
-                jump_map = self.jump_cache.query(regex)
+            jump_map = self.jump_cache.query(regex)
             return Grammar((guide, 0), jump_map)
 
+    def query(self, key: Tuple[str, str], vocab_size: int) -> Future:
+        return self.executor.submit(self._query, key, vocab_size)
+
     def reset(self):
-        if isinstance(self.grammar_cache, FSMCache):
-            self.grammar_cache.reset()
-        if isinstance(self.jump_cache, JumpForwardCache):
-            self.jump_cache.reset()
+        self.grammar_cache.reset()
+        self.jump_cache.reset()
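The central change in this file is that GrammarBackend.query no longer returns a Grammar synchronously: it submits _query to a ThreadPoolExecutor and hands back a concurrent.futures.Future, so grammar compilation overlaps with scheduling. A minimal sketch of that submit-and-poll pattern follows; Backend and slow_compile are illustrative names, not the sglang classes.

# Illustrative sketch of handing compilation to a thread pool and polling the
# Future later; slow_compile and Backend are hypothetical names.
import time
from concurrent.futures import Future, ThreadPoolExecutor


def slow_compile(key):
    time.sleep(0.2)            # pretend this builds an FSM / grammar matcher
    return ("grammar", key)


class Backend:
    def __init__(self):
        self.executor = ThreadPoolExecutor()

    def query(self, key) -> Future:
        # Returns immediately; the compile runs on a worker thread.
        return self.executor.submit(slow_compile, key)


backend = Backend()
fut = backend.query(("json", "{...}"))
print(fut.done())              # likely False right after submission
print(fut.result())            # blocks only if the work is not finished yet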
python/sglang/srt/constrained/fsm_cache.py → python/sglang/srt/constrained/outlines_cache.py

@@ -17,16 +17,17 @@ limitations under the License.
 import logging
 
 from interegular import InvalidSyntax, parse_pattern
-from outlines.fsm.json_schema import build_regex_from_schema
+from outlines.fsm.guide import RegexGuide
+from outlines.models.transformers import TransformerTokenizer
 from transformers import AutoTokenizer
 
-from sglang.srt.constrained import RegexGuide, TransformerTokenizer
+from sglang.srt.constrained import build_regex_from_object
 from sglang.srt.constrained.base_tool_cache import BaseToolCache
 
 logger = logging.getLogger(__name__)
 
 
-class FSMCache(BaseToolCache):
+class OutlinesCache(BaseToolCache):
     def __init__(
         self,
         tokenizer_path,
@@ -74,7 +75,7 @@ class FSMCache(BaseToolCache):
         key_type, key_string = key
         if key_type == "json":
             try:
-                regex = build_regex_from_schema(
+                regex = build_regex_from_object(
                     key_string,
                     whitespace_pattern=self.constrained_json_whitespace_pattern,
                 )
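OutlinesCache now goes through sglang's build_regex_from_object wrapper, which falls back to outlines' build_regex_from_schema on newer outlines versions (see the __init__.py hunk above). A hedged example of the underlying conversion, assuming outlines >= 0.0.44 and using a made-up schema string:

# Hedged sketch: converting a JSON schema to a regex with outlines directly
# (the commit routes this through sglang's build_regex_from_object wrapper).
from outlines.fsm.json_schema import build_regex_from_schema

schema = '{"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}'
regex = build_regex_from_schema(schema, whitespace_pattern=None)
print(regex[:80])              # the regex that the RegexGuide is built from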
python/sglang/srt/constrained/jump_forward.py → python/sglang/srt/constrained/outlines_jump_forward.py

@@ -14,7 +14,7 @@ limitations under the License.
 """
 
 """
-Faster constrained decoding.
+Faster constrained decoding with jump forward decoding / compressed finite state machine.
 Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
 """
 
@@ -23,15 +23,10 @@ import logging
 from collections import defaultdict
 
 import interegular
-import outlines.caching
 from interegular import InvalidSyntax
-
-from sglang.srt.constrained import (
-    FSMInfo,
-    disk_cache,
-    make_byte_level_fsm,
-    make_deterministic_fsm,
-)
+from outlines.caching import cache as disk_cache
+from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
+
 from sglang.srt.constrained.base_tool_cache import BaseToolCache
 
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
 
@@ -47,7 +42,7 @@ class JumpEdge:
     byte_next_state: int = None
 
 
-class JumpForwardMap:
+class OutlinesJumpForwardMap:
     def __init__(self, regex_string):
         @disk_cache()
         def _init_state_to_jump_forward(regex_string):
@@ -169,12 +164,12 @@ class JumpForwardMap:
         )
 
 
-class JumpForwardCache(BaseToolCache):
+class OutlinesJumpCache(BaseToolCache):
     def __init__(self):
         super().__init__()
 
     def init_value(self, regex):
-        forward_map = JumpForwardMap(regex)
+        forward_map = OutlinesJumpForwardMap(regex)
         if forward_map.state_to_jump_forward:
             return forward_map
         else:
@@ -182,7 +177,7 @@ class JumpForwardCache(BaseToolCache):
 def test_main(regex_string):
-    jump_forward_map = JumpForwardMap(regex_string)
+    jump_forward_map = OutlinesJumpForwardMap(regex_string)
     for state, e in jump_forward_map.state_to_jump_forward.items():
         if e.symbol is not None:
             jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
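The renamed OutlinesJumpForwardMap still implements the jump-forward idea from the linked blog post: whenever the regex-derived DFA sits in a state with a single forced continuation, that text can be emitted without running the model. A toy, self-contained illustration of the state-following logic follows (a hand-built transition dict, not the interegular/outlines machinery used here):

# Toy illustration of the jump-forward idea (not the sglang implementation):
# whenever a DFA state has exactly one outgoing edge, the next character is
# forced, so it can be appended without querying the model.
def jump_forward(transitions, state):
    """transitions: {state: {char: next_state}} for a deterministic FSM."""
    forced = []
    while len(transitions.get(state, {})) == 1:
        (char, next_state), = transitions[state].items()
        forced.append(char)
        state = next_state
    return "".join(forced), state


# DFA fragment for the fixed prefix '{"name": "' of a JSON object.
dfa = {i: {c: i + 1} for i, c in enumerate('{"name": "')}
print(jump_forward(dfa, 0))   # ('{"name": "', 10)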
python/sglang/srt/constrained/bnf_cache.py → python/sglang/srt/constrained/xgrammar_cache.py

@@ -17,18 +17,29 @@ from typing import Tuple
 from transformers import AutoTokenizer
 
-from sglang.srt.constrained import (
-    GrammarMatcher,
-    GrammarMatcherInitContext,
-    GrammarMatcherInitContextCache,
-)
+try:
+    from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
+except ImportError as e:
+
+    class Dummy:
+        pass
+
+    GrammarMatcher = Dummy
+    CompiledGrammar = Dummy
+    CachedGrammarCompiler = Dummy
 
 MAX_ROLLBACK_TOKENS = 10
 
 
-class BNFCache:
-    grammar_cache: GrammarMatcherInitContextCache
+class XGrammarJumpCache:
+    """A dummy class."""
+
+    def reset(self):
+        pass
+
+
+class XGrammarBackend:
     def __init__(
         self,
         tokenizer_path,
@@ -41,16 +52,16 @@ class BNFCache:
             return
 
         tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
-        self.grammar_cache = GrammarMatcherInitContextCache(
+        self.grammar_cache: CachedGrammarCompiler = CachedGrammarCompiler(
            tokenizer_or_vocab=tokenizer
         )
 
-    def get_context(self, key: Tuple[str, str]) -> GrammarMatcherInitContext:
+    def get_context(self, key: Tuple[str, str]) -> CompiledGrammar:
         key_type, key_string = key
         if key_type == "json":
-            return self.grammar_cache.get_init_context_for_json_schema(key_string)
+            return self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
         elif key_type == "regex":
-            raise ValueError(f"regex hasn't been supported by xgrammar yet")
+            raise ValueError("regex hasn't been supported by xgrammar yet")
         else:
             raise ValueError(f"Invalid key_type: {key_type}")
@@ -59,3 +70,6 @@ class BNFCache:
         return GrammarMatcher(
             ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size
         )
+
+    def reset(self):
+        self.grammar_cache.clear()
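xgrammar_cache.py now guards its xgrammar import locally and substitutes a Dummy placeholder when the package is missing, so importing the module never fails outright. A generic sketch of that optional-dependency pattern (somelib and FancyMatcher are placeholders, not real packages):

# Generic sketch of the optional-dependency guard used above ("somelib" is a
# placeholder; the commit applies the same pattern to xgrammar).
try:
    from somelib import FancyMatcher
except ImportError:

    class Dummy:
        """Placeholder so annotations and isinstance checks still resolve."""

    FancyMatcher = Dummy


def make_matcher(*args, **kwargs):
    if FancyMatcher is Dummy:
        raise RuntimeError("somelib is not installed; install it to use this backend")
    return FancyMatcher(*args, **kwargs)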
python/sglang/srt/managers/scheduler.py

@@ -29,7 +29,7 @@ import zmq
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained.grammar import GrammarCache
+from sglang.srt.constrained.grammar import GrammarBackend
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -234,11 +234,12 @@ class Scheduler:
             self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
         )
 
-        # Init the FSM cache for constrained generation
+        # Init the grammar cache for constrained generation
         self.grammar_cache = None
+        self.grammar_queue: List[Req] = []
 
         if not server_args.skip_tokenizer_init:
-            self.grammar_cache = GrammarCache(
+            self.grammar_cache = GrammarBackend(
                 server_args.tokenizer_path,
                 {
                     "tokenizer_mode": server_args.tokenizer_mode,
@@ -455,7 +456,7 @@ class Scheduler:
             # By default, only return the logprobs for output tokens
             req.logprob_start_len = len(recv_req.input_ids) - 1
 
-        # Init regex FSM or BNF
+        # Init grammar cache for this request
         if (
             req.sampling_params.json_schema is not None
             or req.sampling_params.regex is not None
@@ -488,7 +489,10 @@ class Scheduler:
             self.max_req_len - len(req.origin_input_ids) - 1,
         )
 
-        self.waiting_queue.append(req)
+        if req.grammar is not None:
+            self.grammar_queue.append(req)
+        else:
+            self.waiting_queue.append(req)
 
     def handle_embedding_request(
         self,
@@ -634,6 +638,17 @@ class Scheduler:
         return self.running_batch
 
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
+        # Check if the grammar queue is ready
+        if self.grammar_queue:
+            new_grammar_queue = []
+            for req in self.grammar_queue:
+                if req.grammar.done():
+                    req.grammar = req.grammar.result()
+                    self.waiting_queue.append(req)
+                else:
+                    new_grammar_queue.append(req)
+            self.grammar_queue = new_grammar_queue
+
         # Handle the cases where prefill is not allowed
         if (
             self.batch_is_full or len(self.waiting_queue) == 0
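On the scheduler side, requests whose grammar is still compiling sit in grammar_queue, and get_new_batch_prefill polls Future.done() each iteration, moving finished requests to waiting_queue without ever blocking the event loop. A simplified sketch of that non-blocking polling loop (poll_ready and the toy futures below are illustrative, not Scheduler code):

# Sketch of the non-blocking polling pattern used in get_new_batch_prefill.
import time
from concurrent.futures import ThreadPoolExecutor


def poll_ready(pending, ready):
    """Move items whose Future finished from `pending` to `ready`; never block."""
    still_pending = []
    for key, fut in pending:
        if fut.done():
            ready.append((key, fut.result()))
        else:
            still_pending.append((key, fut))
    return still_pending


executor = ThreadPoolExecutor()
pending = [(k, executor.submit(lambda k=k: (time.sleep(0.05), k)[1])) for k in range(3)]
ready = []
while pending:
    pending = poll_ready(pending, ready)   # the real scheduler does other work here
print(ready)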
python/sglang/srt/model_executor/model_runner.py

@@ -39,7 +39,6 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
-from sglang.srt.constrained import disable_cache
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
@@ -129,6 +128,8 @@ class ModelRunner:
         if server_args.show_time_cost:
             enable_show_time_cost()
         if server_args.disable_disk_cache:
+            from outlines.caching import disable_cache
+
             disable_cache()
         global_server_args_dict.update(
test/srt/test_json_constrained.py

@@ -100,8 +100,8 @@ class TestJSONConstrained(unittest.TestCase):
         except (TypeError, json.decoder.JSONDecodeError):
             print("JSONDecodeError", text)
             raise
-        assert isinstance(js_obj["name"], str)
-        assert isinstance(js_obj["population"], int)
+        assert isinstance(js_obj["name"], str), f"{js_obj=}"
+        assert isinstance(js_obj["population"], int), f"{js_obj=}"
 
     def test_mix_json_and_other(self):
         json_schemas = [None, None, self.json_schema, self.json_schema] * 10