Commit b77a02cd (unverified)
Authored Oct 26, 2024 by DarkSharpness; committed by GitHub, Oct 25, 2024

[Performance] Support both xgrammar and outlines for constrained decoding (#1752)

Parent: 30643fed
Showing 7 changed files with 325 additions and 77 deletions.
python/sglang/srt/constrained/__init__.py          +18   -0
python/sglang/srt/constrained/bnf_cache.py         +61   -0
python/sglang/srt/constrained/grammar.py          +190   -0
python/sglang/srt/managers/schedule_batch.py       +20  -41
python/sglang/srt/managers/scheduler.py            +21  -23
python/sglang/srt/sampling/sampling_batch_info.py   +7  -13
python/sglang/srt/server_args.py                    +8   -0
python/sglang/srt/constrained/__init__.py

@@ -51,6 +51,21 @@ except ImportError:
         return build_regex_from_schema(schema, whitespace_pattern)
 
+try:
+    from xgrammar import (
+        GrammarMatcher,
+        GrammarMatcherInitContext,
+        GrammarMatcherInitContextCache,
+    )
+except ImportError as e:
+
+    class Dummy:
+        pass
+
+    GrammarMatcher = Dummy
+    GrammarMatcherInitContext = Dummy
+    GrammarMatcherInitContextCache = Dummy
+
 __all__ = [
     "RegexGuide",
     "FSMInfo",
@@ -60,4 +75,7 @@ __all__ = [
     "disk_cache",
     "disable_cache",
     "make_byte_level_fsm",
+    "GrammarMatcher",
+    "GrammarMatcherInitContext",
+    "GrammarMatcherInitContextCache",
 ]
python/sglang/srt/constrained/bnf_cache.py (new file, mode 100644)

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Cache for the compressed finite state machine."""

from typing import Tuple

from transformers import AutoTokenizer

from sglang.srt.constrained import (
    GrammarMatcher,
    GrammarMatcherInitContext,
    GrammarMatcherInitContextCache,
)

MAX_ROLLBACK_TOKENS = 10


class BNFCache:
    grammar_cache: GrammarMatcherInitContextCache

    def __init__(
        self,
        tokenizer_path,
        tokenizer_args_dict,
        skip_tokenizer_init=False,
        whitespace_patterns=None,
    ):
        # TODO(dark): how to deal with whitespace_patterns and skip_tokenizer_init
        if skip_tokenizer_init:
            return

        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
        self.grammar_cache = GrammarMatcherInitContextCache(tokenizer_or_vocab=tokenizer)

    def get_context(self, key: Tuple[str, str]) -> GrammarMatcherInitContext:
        key_type, key_string = key

        if key_type == "json":
            return self.grammar_cache.get_init_context_for_json_schema(key_string)
        elif key_type == "regex":
            raise ValueError(f"regex hasn't been supported by xgrammar yet")
        else:
            raise ValueError(f"Invalid key_type: {key_type}")

    def query(self, key: Tuple[str, str], vocab_size: int) -> GrammarMatcher:
        ctx = self.get_context(key)
        return GrammarMatcher(
            ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size
        )
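For orientation, here is a minimal usage sketch of the new BNFCache (not part of the commit): the tokenizer path, schema, and vocab size are placeholders, and it assumes the xgrammar package and its GrammarMatcher API (as imported above) are available.

from sglang.srt.constrained.bnf_cache import BNFCache

# Hypothetical values for illustration only; any HF tokenizer path works.
cache = BNFCache(
    tokenizer_path="meta-llama/Meta-Llama-3-8B-Instruct",
    tokenizer_args_dict={},
)
schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
# One GrammarMatcher is created per constrained request.
matcher = cache.query(("json", schema), vocab_size=128256)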
python/sglang/srt/constrained/grammar.py (new file, mode 100644)

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Cache for the compressed finite state machine."""

import logging
from typing import List, Optional, Tuple, Union

import torch

from sglang.srt.constrained import GrammarMatcher, RegexGuide
from sglang.srt.constrained.bnf_cache import BNFCache
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache, JumpForwardMap

# from sglang.srt.managers.schedule_batch import Req

logger = logging.getLogger(__name__)

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5


class XGrammarJump:
    pass


class JumpHelper:
    data: Union[List, str]
    state: int
    suffix_ids: List[int]

    def __init__(
        self, data: Union[List, str] = "", state: int = -1, suffix_ids=[]
    ) -> None:
        self.data = data
        self.state = state
        self.suffix_ids = suffix_ids

    def can_jump(self):
        return len(self.data) > 0


class Grammar:
    grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]]
    jump_map: Union[XGrammarJump, JumpForwardMap, None]

    def __init__(
        self,
        grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]],
        jump_map: Union[XGrammarJump, JumpForwardMap, None],
    ) -> None:
        self.grammar = grammar
        self.jump_map = jump_map

    def accept_token(self, token: int):
        if isinstance(self.grammar, GrammarMatcher):
            assert self.grammar.accept_token(token)
        else:
            guide, state = self.grammar
            self.grammar = guide, guide.get_next_state(state, token)

    def try_jump(self, tokenizer) -> JumpHelper:
        if isinstance(self.jump_map, XGrammarJump):
            assert isinstance(self.grammar, GrammarMatcher)
            return JumpHelper(self.grammar.find_jump_forward_string())
        elif isinstance(self.jump_map, JumpForwardMap):
            assert isinstance(self.grammar, Tuple)

            _, state = self.grammar
            jump_forward_bytes = self.jump_map.jump_forward_byte(state)
            if jump_forward_bytes is None or len(jump_forward_bytes) == 0:
                return JumpHelper()  # can't jump

            # preprocess the jump forward string
            suffix_bytes = []
            continuation_range = range(0x80, 0xC0)
            cur_state = state
            while (
                len(jump_forward_bytes)
                and jump_forward_bytes[0][0] in continuation_range
            ):
                # continuation bytes
                byte_edge = jump_forward_bytes.pop(0)
                suffix_bytes.append(byte_edge[0])
                cur_state = byte_edge[1]

            suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
            suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
            return JumpHelper(suffix_ids, cur_state, suffix_bytes)
        else:
            return JumpHelper()  # can't jump

    def jump_forward_str_state(self, helper: JumpHelper) -> Tuple[str, int]:
        if isinstance(helper.data, str):
            return helper.data, -1
        else:
            assert isinstance(self.jump_map, JumpForwardMap)
            return self.jump_map.jump_forward_symbol(helper.state)

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        if isinstance(self.grammar, GrammarMatcher):
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == new_output_ids[i]:
                    k = i + 1
                else:
                    break

            # rollback to the last token that is the same
            if k < len(old_output_ids):
                self.grammar.rollback(len(old_output_ids) - k)

            for i in range(k, len(new_output_ids)):
                assert self.grammar.accept_token(new_output_ids[i])
        else:
            self.grammar = self.grammar[0], next_state

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
        if isinstance(self.grammar, GrammarMatcher):
            # Note that this bitmask is a bitset, not bool
            bitmask = self.grammar.find_next_token_bitmask()
            # Mask the tokens that are not allowed
            vocab_mask[
                self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size)
            ] = 1
        else:
            guide, state = self.grammar
            vocab_mask.fill_(1)
            vocab_mask[guide.get_next_instruction(state).tokens] = 0


class GrammarCache:
    grammar_cache: Union[BNFCache, FSMCache]
    jump_cache: Union[XGrammarJump, JumpForwardCache, None]

    def __init__(
        self,
        tokenizer_path,
        tokenizer_args_dict,
        skip_tokenizer_init=False,
        whitespace_patterns=None,
        backend=None,
        allow_jump=False,
    ):
        if backend == "xgrammar":
            self.grammar_cache = BNFCache(
                tokenizer_path=tokenizer_path,
                tokenizer_args_dict=tokenizer_args_dict,
                skip_tokenizer_init=skip_tokenizer_init,
                whitespace_patterns=whitespace_patterns,
            )
            self.jump_cache = XGrammarJump() if allow_jump else None
        else:
            assert backend == "outlines"
            self.grammar_cache = FSMCache(
                tokenizer_path=tokenizer_path,
                tokenizer_args_dict=tokenizer_args_dict,
                skip_tokenizer_init=skip_tokenizer_init,
                constrained_json_whitespace_pattern=whitespace_patterns,
                enable=True,
            )
            self.jump_cache = JumpForwardCache() if allow_jump else None

    def query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
        if isinstance(self.grammar_cache, BNFCache):
            assert not isinstance(self.jump_cache, JumpForwardCache)
            return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache)
        else:
            jump_map = None
            guide, regex = self.grammar_cache.query(key)
            if isinstance(self.jump_cache, JumpForwardCache):
                jump_map = self.jump_cache.query(regex)
            return Grammar((guide, 0), jump_map)

    def reset(self):
        if isinstance(self.grammar_cache, FSMCache):
            self.grammar_cache.reset()
        if isinstance(self.jump_cache, JumpForwardCache):
            self.jump_cache.reset()
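A rough sketch (not part of the commit) of how Grammar and GrammarCache are meant to be driven during decoding; the backend, tokenizer path, and vocab size are illustrative placeholders, and the real wiring lives in the scheduler and sampling changes below.

import torch

from sglang.srt.constrained.grammar import GrammarCache

# One cache per scheduler; "outlines" is the default backend, "xgrammar" is opt-in.
grammar_cache = GrammarCache(
    tokenizer_path="meta-llama/Meta-Llama-3-8B-Instruct",  # hypothetical
    tokenizer_args_dict={},
    backend="outlines",
    allow_jump=True,
)
vocab_size = 128256  # hypothetical

# One Grammar per constrained request, keyed by ("json", schema) or ("regex", pattern).
grammar = grammar_cache.query(("json", '{"type": "object"}'), vocab_size)

# Each decoding step: mark rejected token ids in a mask, sample, then advance the grammar.
vocab_mask = torch.zeros(vocab_size, dtype=torch.bool)
grammar.fill_vocab_mask(vocab_mask, vocab_size)  # sets 1 for disallowed tokens
# next_token_id = sample(logits.masked_fill(vocab_mask, float("-inf")))  # pseudocode
# grammar.accept_token(next_token_id)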
python/sglang/srt/managers/schedule_batch.py

@@ -37,8 +37,7 @@ import torch
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained import RegexGuide
-from sglang.srt.constrained.jump_forward import JumpForwardMap
+from sglang.srt.constrained.grammar import Grammar
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -247,9 +246,7 @@ class Req:
         self.embedding = None
 
         # Constrained decoding
-        self.regex_fsm: RegexGuide = None
-        self.regex_fsm_state: int = 0
-        self.jump_forward_map: JumpForwardMap = None
+        self.grammar: Optional[Grammar] = None
 
         # For Qwen2-VL
         self.mrope_position_delta = []  # use mutable object
@@ -359,6 +356,8 @@ class Req:
         return
 
     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
+        assert self.grammar is not None and self.tokenizer is not None
+
         if self.origin_input_text is None:
             # Recovering text can only use unpadded ids
             self.origin_input_text = self.tokenizer.decode(
@@ -398,7 +397,8 @@ class Req:
                 self.surr_offset = self.read_offset - i
                 break
 
-        self.regex_fsm_state = next_state
+        # update the inner state of the grammar
+        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)
 
         if self.return_logprob:
             # For fast-forward part's logprobs
@@ -468,8 +468,8 @@ class ScheduleBatch:
     # Stream
     has_stream: bool = False
 
-    # Has regex
-    has_regex: bool = False
+    # Has grammar
+    has_grammar: bool = False
 
     # device
     device: str = "cuda"
@@ -477,7 +477,7 @@ class ScheduleBatch:
     @classmethod
     def init_new(
         cls,
-        reqs,
+        reqs: List[Req],
         req_to_token_pool,
         token_to_kv_pool,
         tree_cache,
@@ -491,7 +491,7 @@ class ScheduleBatch:
             model_config=model_config,
             return_logprob=any(req.return_logprob for req in reqs),
             has_stream=any(req.stream for req in reqs),
-            has_regex=any(req.regex_fsm for req in reqs),
+            has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
         )
@@ -803,26 +803,10 @@ class ScheduleBatch:
         keep_indices = set(i for i in range(len(self.reqs)))
 
         for i, req in enumerate(self.reqs):
-            if req.jump_forward_map is not None:
-                jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
-                    req.regex_fsm_state
-                )
-                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
-                    suffix_bytes = []
-                    continuation_range = range(0x80, 0xC0)
-                    cur_state = req.regex_fsm_state
-                    while (
-                        len(jump_forward_bytes)
-                        and jump_forward_bytes[0][0] in continuation_range
-                    ):
-                        # continuation bytes
-                        byte_edge = jump_forward_bytes.pop(0)
-                        suffix_bytes.append(byte_edge[0])
-                        cur_state = byte_edge[1]
-
-                    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
-                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
+            if req.grammar is not None:
+                jump_helper = req.grammar.try_jump(req.tokenizer)
+                if jump_helper.can_jump():
+                    suffix_ids = jump_helper.suffix_ids
 
                     # Current ids, for cache and revert
                     cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                     cur_output_ids = req.output_ids
@@ -836,10 +820,8 @@ class ScheduleBatch:
                     (
                         jump_forward_str,
                         next_state,
-                    ) = req.jump_forward_map.jump_forward_symbol(cur_state)
+                    ) = req.grammar.jump_forward_str_state(jump_helper)
 
                     # Make the incrementally decoded text part of jump_forward_str
                     # so that the UTF-8 will not corrupt
                     jump_forward_str = new_text + jump_forward_str
                     if not req.jump_forward_and_retokenize(
                         jump_forward_str, next_state
@@ -946,7 +928,7 @@ class ScheduleBatch:
             self.top_logprobs_nums = None
 
         self.has_stream = any(req.stream for req in self.reqs)
-        self.has_regex = any(req.regex_fsm for req in self.reqs)
+        self.has_grammar = any(req.grammar for req in self.reqs)
 
         self.sampling_info.filter_batch(keep_indices, new_indices)
@@ -979,7 +961,7 @@ class ScheduleBatch:
         self.return_logprob = self.return_logprob or other.return_logprob
         self.has_stream = self.has_stream or other.has_stream
-        self.has_regex = self.has_regex or other.has_regex
+        self.has_grammar = self.has_grammar or other.has_grammar
 
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode():
@@ -989,13 +971,10 @@ class ScheduleBatch:
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
 
-        if self.has_regex:
-            self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
-            self.sampling_info.regex_fsm_states = [
-                req.regex_fsm_state for req in self.reqs
-            ]
+        if self.has_grammar:
+            self.sampling_info.grammars = [req.grammar for req in self.reqs]
         else:
-            self.sampling_info.regex_fsms = None
+            self.sampling_info.grammars = None
 
         global bid
         bid += 1
python/sglang/srt/managers/scheduler.py

@@ -29,8 +29,7 @@ import zmq
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained.fsm_cache import FSMCache
-from sglang.srt.constrained.jump_forward import JumpForwardCache
+from sglang.srt.constrained.grammar import GrammarCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -225,17 +224,20 @@ class Scheduler:
         )
 
         # Init the FSM cache for constrained generation
+        self.grammar_cache = None
+
         if not server_args.skip_tokenizer_init:
-            self.regex_fsm_cache = FSMCache(
+            self.grammar_cache = GrammarCache(
                 server_args.tokenizer_path,
                 {
                     "tokenizer_mode": server_args.tokenizer_mode,
                     "trust_remote_code": server_args.trust_remote_code,
                 },
                 skip_tokenizer_init=server_args.skip_tokenizer_init,
-                constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
+                whitespace_patterns=server_args.constrained_json_whitespace_pattern,
+                backend=server_args.grammar_backend,
+                allow_jump=not server_args.disable_regex_jump_forward,
             )
-        self.jump_forward_cache = JumpForwardCache()
 
         # Init new token estimation
         assert (
@@ -402,22 +404,20 @@ class Scheduler:
             # By default, only return the logprobs for output tokens
             req.logprob_start_len = len(recv_req.input_ids) - 1
 
-        # Init regex FSM
+        # Init regex FSM or BNF
         if (
             req.sampling_params.json_schema is not None
             or req.sampling_params.regex is not None
         ):
+            assert self.grammar_cache is not None
             if req.sampling_params.json_schema is not None:
-                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
-                    ("json", req.sampling_params.json_schema)
+                req.grammar = self.grammar_cache.query(
+                    ("json", req.sampling_params.json_schema),
+                    self.model_config.vocab_size,
                 )
             elif req.sampling_params.regex is not None:
-                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
-                    ("regex", req.sampling_params.regex)
+                req.grammar = self.grammar_cache.query(
+                    ("regex", req.sampling_params.regex), self.model_config.vocab_size
                 )
-            if not self.disable_regex_jump_forward:
-                req.jump_forward_map = self.jump_forward_cache.query(
-                    computed_regex_string
-                )
 
         # Truncate prompts that are too long
@@ -796,10 +796,8 @@ class Scheduler:
             elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                 self.tree_cache.cache_unfinished_req(req)
 
-            if req.regex_fsm is not None:
-                req.regex_fsm_state = req.regex_fsm.get_next_state(
-                    req.regex_fsm_state, next_token_ids[i]
-                )
+            if req.grammar is not None:
+                req.grammar.accept_token(next_token_ids[i])
 
             if req.return_logprob:
                 logprob_pt += self.add_logprob_return_values(
@@ -855,10 +853,8 @@ class Scheduler:
             req.output_ids.append(next_token_id)
             req.check_finished()
 
-            if req.regex_fsm is not None:
-                req.regex_fsm_state = req.regex_fsm.get_next_state(
-                    req.regex_fsm_state, next_token_id
-                )
+            if req.grammar is not None:
+                req.grammar.accept_token(next_token_id)
 
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
@@ -1056,7 +1052,9 @@ class Scheduler:
         ):
             self.tree_cache.reset()
             self.tree_cache_metrics = {"total": 0, "hit": 0}
-            self.regex_fsm_cache.reset()
+            if self.grammar_cache is not None:
+                self.grammar_cache.reset()
+            # TODO(dark): reset the bnf cache
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()
python/sglang/srt/sampling/sampling_batch_info.py

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, List, Optional
 import torch
 
 import sglang.srt.sampling.penaltylib as penaltylib
-from sglang.srt.constrained import RegexGuide
+from sglang.srt.constrained.grammar import Grammar
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -29,11 +29,9 @@ class SamplingBatchInfo:
     # Bias Tensors
     vocab_size: int
     logit_bias: torch.Tensor = None
-    vocab_mask: torch.Tensor = None
+    vocab_mask: Optional[torch.Tensor] = None
 
-    # FSM states
-    regex_fsms: List[RegexGuide] = None
-    regex_fsm_states: List[int] = None
+    grammars: Optional[List[Optional[Grammar]]] = None
 
     # Penalizer
     penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
@@ -136,8 +134,7 @@ class SamplingBatchInfo:
             self.linear_penalties = penalizer.apply(self.linear_penalties)
 
     def update_regex_vocab_mask(self):
-        has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
-        if not has_regex:
+        if not self.grammars or not any(grammar for grammar in self.grammars):
             self.vocab_mask = None
             return
@@ -147,12 +144,9 @@ class SamplingBatchInfo:
             dtype=torch.bool,
             device=self.device,
         )
-        for i, regex_fsm in enumerate(self.regex_fsms):
-            if regex_fsm is not None:
-                self.vocab_mask[i].fill_(1)
-                self.vocab_mask[i][
-                    regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
-                ] = 0
+        for i, grammar in enumerate(self.grammars):
+            if grammar is not None:
+                grammar.fill_vocab_mask(self.vocab_mask[i], self.vocab_size)
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         if self.penalizer_orchestrator:
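For context (not from the commit), the mask built by update_regex_vocab_mask uses True/1 to mark token ids the grammar rejects. A minimal, hedged sketch of how such a mask can be applied to logits before sampling follows; shapes are made up, and the actual application happens elsewhere in sglang's sampling path.

import torch

# Illustrative shapes: [batch_size, vocab_size]; True marks a rejected token.
logits = torch.randn(2, 32000)
vocab_mask = torch.zeros(2, 32000, dtype=torch.bool)
vocab_mask[:, 100:] = True  # pretend only ids 0..99 are allowed

masked = logits.masked_fill(vocab_mask, float("-inf"))
probs = torch.softmax(masked, dim=-1)  # rejected tokens end up with probability 0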
python/sglang/srt/server_args.py

@@ -102,6 +102,7 @@ class ServerArgs:
     # Kernel backend
     attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
+    grammar_backend: Optional[str] = "outlines"
 
     # Optimization/debug options
     disable_flashinfer: bool = False
@@ -537,6 +538,13 @@ class ServerArgs:
             default=ServerArgs.sampling_backend,
             help="Choose the kernels for sampling layers.",
         )
+        parser.add_argument(
+            "--grammar-backend",
+            type=str,
+            choices=["xgrammar", "outlines"],
+            default=ServerArgs.grammar_backend,
+            help="Choose the backend for constrained decoding.",
+        )
 
         # Optimization/debug options
         parser.add_argument(
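With the new argument, the constrained-decoding backend is selected per server: on the command line via --grammar-backend xgrammar or --grammar-backend outlines, as added above. A hedged programmatic sketch (the model path is a placeholder):

from sglang.srt.server_args import ServerArgs

# The dataclass default stays "outlines", so existing deployments keep their behavior;
# xgrammar has to be requested explicitly.
args = ServerArgs(
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder
    grammar_backend="xgrammar",
)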