Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
54479d6f
Unverified
Commit
54479d6f
authored
Nov 13, 2024
by
Lianmin Zheng
Committed by
GitHub
Nov 13, 2024
Browse files
Fix grammar backend for tensor parallelism (#2020)
parent
ba069a24
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
251 additions
and
227 deletions
+251
-227
python/sglang/srt/constrained/base_grammar_backend.py
python/sglang/srt/constrained/base_grammar_backend.py
+72
-0
python/sglang/srt/constrained/outlines_backend.py
python/sglang/srt/constrained/outlines_backend.py
+22
-60
python/sglang/srt/constrained/outlines_jump_forward.py
python/sglang/srt/constrained/outlines_jump_forward.py
+82
-98
python/sglang/srt/constrained/xgrammar_backend.py
python/sglang/srt/constrained/xgrammar_backend.py
+28
-41
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+2
-1
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+45
-27
No files found.
python/sglang/srt/constrained/base_
tool_cache
.py
→
python/sglang/srt/constrained/base_
grammar_backend
.py
View file @
54479d6f
...
...
@@ -13,90 +13,60 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
"""
Base cache class for
constrained decoding
tools
."""
"""
The baseclass of backends for grammar-guided
constrained decoding."""
import
time
from
concurrent.futures
import
Future
,
ThreadPoolExecutor
from
dataclasses
import
dataclass
from
threading
import
Event
,
Lock
from
typing
import
Any
,
Dict
,
Tuple
from
typing
import
Any
,
Optional
,
Tuple
@
dataclass
class
MapEntry
:
event
:
Event
class
CacheEntry
:
value
:
Any
event
:
Event
def
__iter__
(
self
):
return
iter
((
self
.
event
,
self
.
value
))
class
BaseGrammarObject
:
pass
class
BaseToolCache
:
def
__init__
(
self
,
enable
=
True
):
self
.
enable
:
bool
=
enable
self
.
cache
:
Dict
[
str
,
MapEntry
]
=
{}
self
.
metrics
:
Dict
[
str
,
Any
]
=
{}
self
.
lock_cache
:
Lock
=
Lock
()
self
.
lock_metrics
:
Lock
=
Lock
()
self
.
reset
()
class
BaseGrammarBackend
:
def
__init__
(
self
):
self
.
executor
=
ThreadPoolExecutor
()
self
.
cache
=
{}
self
.
cache_lock
=
Lock
()
def
reset
(
self
):
with
self
.
lock_cache
:
self
.
cache
=
{}
with
self
.
lock_metrics
:
self
.
metrics
=
{
"total"
:
0
,
"hit"
:
0
,
"avg_init_time"
:
0
}
def
_init_with_timer
(
self
,
key
)
->
Tuple
[
Any
,
float
]:
start
=
time
.
monotonic
()
val
=
self
.
init_value
(
key
)
init_time
=
time
.
monotonic
()
-
start
return
val
,
init_time
def
update_time
(
self
,
init_time
):
with
self
.
lock_metrics
:
curr_total
=
self
.
metrics
[
"total"
]
new_total
=
curr_total
+
1
# Update average init time without old_avg * old_total to avoid overflow.
self
.
metrics
[
"avg_init_time"
]
=
(
init_time
/
new_total
)
+
(
curr_total
/
new_total
)
*
self
.
metrics
[
"avg_init_time"
]
def
query
(
self
,
key
):
if
not
self
.
enable
:
value
,
init_time
=
self
.
_init_with_timer
(
key
)
self
.
update_time
(
init_time
)
return
value
with
self
.
lock_cache
:
def
init_value
(
self
,
key
:
Tuple
[
str
,
str
])
->
BaseGrammarObject
:
with
self
.
cache_lock
:
if
key
in
self
.
cache
:
entry
=
self
.
cache
[
key
]
cache_hit
=
True
entry
=
self
.
cache
[
key
]
else
:
entry
=
MapEntry
(
Event
(),
None
)
self
.
cache
[
key
]
=
entry
cache_hit
=
False
with
self
.
lock_metrics
:
self
.
metrics
[
"total"
]
+=
1
if
cache_hit
:
self
.
metrics
[
"hit"
]
+=
1
entry
=
CacheEntry
(
None
,
Event
())
self
.
cache
[
key
]
=
entry
if
cache_hit
:
entry
.
event
.
wait
()
else
:
entry
.
value
,
init_time
=
self
.
_init_with_timer
(
key
)
self
.
update_time
(
init_time
)
entry
.
value
=
self
.
init_value_impl
(
key
)
entry
.
event
.
set
()
return
entry
.
value
return
entry
.
value
.
copy
()
def
init_value
(
self
,
key
)
:
def
init_value
_impl
(
self
,
key
:
Tuple
[
str
,
str
])
->
BaseGrammarObject
:
raise
NotImplementedError
()
def
get_cache_hit_rate
(
self
):
with
self
.
lock_metrics
:
return
self
.
metrics
[
"hit"
]
/
max
(
self
.
metrics
[
"total"
],
1
)
def
get_cached_value
(
self
,
key
:
Tuple
[
str
,
str
])
->
Optional
[
BaseGrammarObject
]:
with
self
.
cache_lock
:
entry
=
self
.
cache
.
get
(
key
)
if
not
entry
or
not
entry
.
event
.
is_set
():
return
None
return
self
.
cache
[
key
].
value
.
copy
()
def
get_avg_init_time
(
self
):
with
self
.
lock_metrics
:
return
self
.
metrics
[
"avg_init_time"
]
def
get_future_value
(
self
,
key
:
Tuple
[
str
,
str
])
->
Future
:
return
self
.
executor
.
submit
(
self
.
init_value
,
key
)
def
reset
(
self
):
with
self
.
cache_lock
:
self
.
cache
.
clear
()
python/sglang/srt/constrained/outlines_backend.py
View file @
54479d6f
...
...
@@ -17,20 +17,17 @@ limitations under the License.
import
json
import
logging
from
concurrent.futures
import
Future
,
ThreadPoolExecutor
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
interegular
import
InvalidSyntax
,
parse_pattern
from
outlines.fsm.guide
import
RegexGuide
from
outlines.models.transformers
import
TransformerTokenizer
from
pydantic
import
BaseModel
from
sglang.srt.constrained.base_tool_cache
import
BaseToolCache
from
sglang.srt.constrained.outlines_jump_forward
import
(
OutlinesJumpForwardCache
,
OutlinesJumpForwardMap
,
from
sglang.srt.constrained.base_grammar_backend
import
(
BaseGrammarBackend
,
BaseGrammarObject
,
)
from
sglang.srt.constrained.outlines_jump_forward
import
OutlinesJumpForwardMap
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -41,6 +38,7 @@ except ImportError:
# Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
# which only accepts string schema as input.
from
outlines.fsm.json_schema
import
build_regex_from_schema
from
pydantic
import
BaseModel
def
build_regex_from_object
(
object
:
Union
[
str
,
BaseModel
,
Dict
],
whitespace_pattern
:
Optional
[
str
]
=
None
...
...
@@ -54,16 +52,15 @@ except ImportError:
return
build_regex_from_schema
(
schema
,
whitespace_pattern
)
class
OutlinesGrammar
:
class
OutlinesGrammar
(
BaseGrammarObject
)
:
def
__init__
(
self
,
guide
:
RegexGuide
,
state
:
int
,
jump_forward_map
:
Union
[
OutlinesJumpForwardMap
,
None
],
)
->
None
:
self
.
guide
=
guide
self
.
state
=
state
self
.
jump_forward_map
=
jump_forward_map
self
.
state
=
0
def
accept_token
(
self
,
token
:
int
):
self
.
state
=
self
.
guide
.
get_next_state
(
self
.
state
,
token
)
...
...
@@ -105,46 +102,18 @@ class OutlinesGrammar:
vocab_mask
.
fill_
(
1
)
vocab_mask
[
self
.
guide
.
get_next_instruction
(
self
.
state
).
tokens
]
=
0
def
copy
(
self
):
return
OutlinesGrammar
(
self
.
guide
,
self
.
jump_forward_map
)
class
OutlinesGrammarBackend
:
class
OutlinesGrammarBackend
(
BaseGrammarBackend
):
def
__init__
(
self
,
tokenizer
,
whitespace_pattern
s
:
bool
,
whitespace_pattern
:
bool
,
allow_jump_forward
:
bool
,
):
self
.
executor
=
ThreadPoolExecutor
()
self
.
grammar_cache
=
OutlinesCache
(
tokenizer
,
whitespace_pattern
=
whitespace_patterns
,
)
self
.
jump_forward_cache
=
(
OutlinesJumpForwardCache
()
if
allow_jump_forward
else
None
)
def
_query
(
self
,
key
:
Tuple
[
str
,
str
])
->
OutlinesGrammar
:
guide
,
regex
=
self
.
grammar_cache
.
query
(
key
)
jump_forward_map
=
(
self
.
jump_forward_cache
.
query
(
regex
)
if
self
.
jump_forward_cache
else
None
)
return
OutlinesGrammar
(
guide
,
0
,
jump_forward_map
)
def
query
(
self
,
key
:
Tuple
[
str
,
str
])
->
Future
:
return
self
.
executor
.
submit
(
self
.
_query
,
key
)
def
reset
(
self
):
self
.
grammar_cache
.
reset
()
if
self
.
jump_forward_cache
:
self
.
jump_forward_cache
.
reset
()
class
OutlinesCache
(
BaseToolCache
):
def
__init__
(
self
,
tokenizer
,
whitespace_pattern
=
None
,
):
super
().
__init__
(
enable
=
True
)
super
().
__init__
()
try
:
self
.
outlines_tokenizer
=
TransformerTokenizer
(
tokenizer
)
...
...
@@ -167,9 +136,10 @@ class OutlinesCache(BaseToolCache):
self
.
outlines_tokenizer
.
vocabulary
=
(
self
.
outlines_tokenizer
.
tokenizer
.
get_vocab
()
)
self
.
allow_jump_forward
=
allow_jump_forward
self
.
whitespace_pattern
=
whitespace_pattern
def
init_value
(
self
,
key
)
:
def
init_value
_impl
(
self
,
key
:
Tuple
[
str
,
str
])
->
OutlinesGrammar
:
key_type
,
key_string
=
key
if
key_type
==
"json"
:
try
:
...
...
@@ -186,18 +156,10 @@ class OutlinesCache(BaseToolCache):
regex
=
key_string
else
:
raise
ValueError
(
f
"Invalid key_type:
{
key_type
}
"
)
try
:
parse_pattern
(
regex
)
except
InvalidSyntax
as
e
:
logger
.
warning
(
f
"skip invalid regex guide:
{
regex
=
}
,
{
e
=
}
"
)
return
None
,
regex
ret
=
RegexGuide
(
regex
,
self
.
outlines_tokenizer
),
regex
return
ret
def
_query
(
self
,
key
:
Tuple
[
str
,
str
]):
guide
,
regex
=
self
.
grammar_cache
.
query
(
key
)
jump_forward_map
=
(
self
.
jump_forward_cache
.
query
(
regex
)
if
self
.
jump_forward_cache
else
None
)
return
OutlinesGrammar
(
guide
,
0
,
jump_forward_map
)
guide
=
RegexGuide
(
regex
,
self
.
outlines_tokenizer
)
if
self
.
allow_jump_forward
:
jump_forward_map
=
OutlinesJumpForwardMap
(
regex
)
else
:
jump_forward_map
=
None
return
OutlinesGrammar
(
guide
,
jump_forward_map
)
python/sglang/srt/constrained/outlines_jump_forward.py
View file @
54479d6f
...
...
@@ -27,8 +27,6 @@ from interegular import InvalidSyntax
from
outlines.caching
import
cache
as
disk_cache
from
outlines.fsm.regex
import
FSMInfo
,
make_byte_level_fsm
,
make_deterministic_fsm
from
sglang.srt.constrained.base_tool_cache
import
BaseToolCache
IP_REGEX
=
r
"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -42,92 +40,90 @@ class JumpEdge:
byte_next_state
:
int
=
None
@
disk_cache
()
def
init_state_to_jump_forward
(
regex_string
):
try
:
regex_pattern
=
interegular
.
parse_pattern
(
regex_string
)
except
InvalidSyntax
as
e
:
logger
.
warning
(
f
"skip invalid regex:
{
regex_string
}
,
{
e
=
}
"
)
return
byte_fsm
=
make_byte_level_fsm
(
regex_pattern
.
to_fsm
().
reduce
(),
keep_utf8
=
True
)
regex_fsm
,
_
=
make_deterministic_fsm
(
byte_fsm
)
fsm_info
:
FSMInfo
=
regex_fsm
.
fsm_info
symbol_to_id
=
fsm_info
.
alphabet_symbol_mapping
id_to_symbol
=
{}
for
symbol
,
id_
in
symbol_to_id
.
items
():
id_to_symbol
.
setdefault
(
id_
,
[]).
append
(
symbol
)
transitions
=
fsm_info
.
transitions
outgoings_ct
=
defaultdict
(
int
)
# NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally
for
s
in
fsm_info
.
finals
:
outgoings_ct
[
s
]
=
1
state_to_jump_forward
=
{}
for
(
state
,
id_
),
next_state
in
transitions
.
items
():
if
id_
==
fsm_info
.
alphabet_anything_value
:
# Arbitrarily symbol cannot be recognized as jump forward
continue
symbols
=
id_to_symbol
[
id_
]
for
c
in
symbols
:
if
len
(
c
)
>
1
:
# Skip byte level transitions like c = "5E"
continue
outgoings_ct
[
state
]
+=
1
if
outgoings_ct
[
state
]
>
1
:
if
state
in
state_to_jump_forward
:
del
state_to_jump_forward
[
state
]
break
state_to_jump_forward
[
state
]
=
JumpEdge
(
symbol
=
c
,
symbol_next_state
=
next_state
,
)
# Process the byte level jump forward
outgoings_ct
=
defaultdict
(
int
)
for
s
in
fsm_info
.
finals
:
outgoings_ct
[
s
]
=
1
for
(
state
,
id_
),
next_state
in
transitions
.
items
():
if
id_
==
fsm_info
.
alphabet_anything_value
:
continue
symbols
=
id_to_symbol
[
id_
]
for
c
in
symbols
:
byte_
=
None
if
len
(
c
)
==
1
and
ord
(
c
)
<
0x80
:
# ASCII character
byte_
=
ord
(
c
)
elif
len
(
c
)
>
1
:
# FIXME: This logic is due to the leading \x00
# https://github.com/outlines-dev/outlines/pull/930
byte_
=
int
(
symbols
[
0
][
1
:],
16
)
if
byte_
is
not
None
:
outgoings_ct
[
state
]
+=
1
if
outgoings_ct
[
state
]
>
1
:
if
state
in
state_to_jump_forward
:
del
state_to_jump_forward
[
state
]
break
e
=
state_to_jump_forward
.
get
(
state
,
JumpEdge
())
e
.
byte
=
byte_
e
.
byte_next_state
=
next_state
state_to_jump_forward
[
state
]
=
e
return
state_to_jump_forward
class
OutlinesJumpForwardMap
:
def
__init__
(
self
,
regex_string
):
@
disk_cache
()
def
_init_state_to_jump_forward
(
regex_string
):
try
:
regex_pattern
=
interegular
.
parse_pattern
(
regex_string
)
except
InvalidSyntax
as
e
:
logger
.
warning
(
f
"skip invalid regex:
{
regex_string
}
,
{
e
=
}
"
)
self
.
state_to_jump_forward
=
None
return
byte_fsm
=
make_byte_level_fsm
(
regex_pattern
.
to_fsm
().
reduce
(),
keep_utf8
=
True
)
regex_fsm
,
_
=
make_deterministic_fsm
(
byte_fsm
)
fsm_info
:
FSMInfo
=
regex_fsm
.
fsm_info
symbol_to_id
=
fsm_info
.
alphabet_symbol_mapping
id_to_symbol
=
{}
for
symbol
,
id_
in
symbol_to_id
.
items
():
id_to_symbol
.
setdefault
(
id_
,
[]).
append
(
symbol
)
transitions
=
fsm_info
.
transitions
outgoings_ct
=
defaultdict
(
int
)
# NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally
for
s
in
fsm_info
.
finals
:
outgoings_ct
[
s
]
=
1
state_to_jump_forward
=
{}
for
(
state
,
id_
),
next_state
in
transitions
.
items
():
if
id_
==
fsm_info
.
alphabet_anything_value
:
# Arbitrarily symbol cannot be recognized as jump forward
continue
symbols
=
id_to_symbol
[
id_
]
for
c
in
symbols
:
if
len
(
c
)
>
1
:
# Skip byte level transitions like c = "5E"
continue
outgoings_ct
[
state
]
+=
1
if
outgoings_ct
[
state
]
>
1
:
if
state
in
state_to_jump_forward
:
del
state_to_jump_forward
[
state
]
break
state_to_jump_forward
[
state
]
=
JumpEdge
(
symbol
=
c
,
symbol_next_state
=
next_state
,
)
# Process the byte level jump forward
outgoings_ct
=
defaultdict
(
int
)
for
s
in
fsm_info
.
finals
:
outgoings_ct
[
s
]
=
1
for
(
state
,
id_
),
next_state
in
transitions
.
items
():
if
id_
==
fsm_info
.
alphabet_anything_value
:
continue
symbols
=
id_to_symbol
[
id_
]
for
c
in
symbols
:
byte_
=
None
if
len
(
c
)
==
1
and
ord
(
c
)
<
0x80
:
# ASCII character
byte_
=
ord
(
c
)
elif
len
(
c
)
>
1
:
# FIXME: This logic is due to the leading \x00
# https://github.com/outlines-dev/outlines/pull/930
byte_
=
int
(
symbols
[
0
][
1
:],
16
)
if
byte_
is
not
None
:
outgoings_ct
[
state
]
+=
1
if
outgoings_ct
[
state
]
>
1
:
if
state
in
state_to_jump_forward
:
del
state_to_jump_forward
[
state
]
break
e
=
state_to_jump_forward
.
get
(
state
,
JumpEdge
())
e
.
byte
=
byte_
e
.
byte_next_state
=
next_state
state_to_jump_forward
[
state
]
=
e
return
state_to_jump_forward
self
.
state_to_jump_forward
=
_init_state_to_jump_forward
(
regex_string
)
self
.
state_to_jump_forward
=
init_state_to_jump_forward
(
regex_string
)
def
jump_forward_symbol
(
self
,
state
):
jump_forward_str
=
""
...
...
@@ -164,18 +160,6 @@ class OutlinesJumpForwardMap:
)
class
OutlinesJumpForwardCache
(
BaseToolCache
):
def
__init__
(
self
):
super
().
__init__
()
def
init_value
(
self
,
regex
):
forward_map
=
OutlinesJumpForwardMap
(
regex
)
if
forward_map
.
state_to_jump_forward
:
return
forward_map
else
:
return
None
def
test_main
(
regex_string
):
jump_forward_map
=
OutlinesJumpForwardMap
(
regex_string
)
for
state
,
e
in
jump_forward_map
.
state_to_jump_forward
.
items
():
...
...
python/sglang/srt/constrained/xgrammar_backend.py
View file @
54479d6f
...
...
@@ -15,38 +15,36 @@ limitations under the License.
"""Constrained decoding with xgrammar backend."""
from
concurrent.futures
import
Future
,
ThreadPoolExecutor
from
typing
import
List
,
Tuple
import
torch
from
xgrammar
import
CachedGrammarCompiler
,
CompiledGrammar
,
GrammarMatcher
try
:
from
xgrammar
import
CachedGrammarCompiler
,
CompiledGrammar
,
GrammarMatcher
import_error
=
None
except
ImportError
as
e
:
import_error
=
e
class
Dummy
:
pass
GrammarMatcher
=
CompiledGrammar
=
CachedGrammarCompiler
=
Dummy
from
sglang.srt.constrained.base_grammar_backend
import
(
BaseGrammarBackend
,
BaseGrammarObject
,
)
MAX_ROLLBACK_TOKENS
=
10
class
XGrammarGrammar
:
class
XGrammarGrammar
(
BaseGrammarObject
)
:
def
__init__
(
self
,
matcher
:
GrammarMatcher
,
vocab_size
:
int
)
->
None
:
def
__init__
(
self
,
matcher
:
GrammarMatcher
,
vocab_size
:
int
,
ctx
:
CompiledGrammar
)
->
None
:
self
.
matcher
=
matcher
self
.
vocab_size
=
vocab_size
self
.
ctx
=
ctx
def
accept_token
(
self
,
token
:
int
):
assert
self
.
matcher
.
accept_token
(
token
)
def
try_jump_forward
(
self
,
tokenizer
)
->
Tuple
[
List
[
int
],
str
]:
return
[],
self
.
matcher
.
find_jump_forward_string
()
s
=
self
.
matcher
.
find_jump_forward_string
()
if
s
:
return
[],
s
return
None
def
jump_forward_str_state
(
self
,
helper
:
Tuple
[
List
[
int
],
str
])
->
Tuple
[
str
,
int
]:
_
,
data
=
helper
...
...
@@ -77,51 +75,40 @@ class XGrammarGrammar:
self
.
matcher
.
get_rejected_tokens_from_bitmask
(
bitmask
,
self
.
vocab_size
)
]
=
1
def
copy
(
self
):
matcher
=
GrammarMatcher
(
self
.
ctx
,
max_rollback_tokens
=
MAX_ROLLBACK_TOKENS
,
mask_vocab_size
=
self
.
vocab_size
,
)
return
XGrammarGrammar
(
matcher
,
self
.
vocab_size
,
self
.
ctx
)
class
XGrammarGrammarBackend
:
class
XGrammarGrammarBackend
(
BaseGrammarBackend
)
:
def
__init__
(
self
,
tokenizer
,
vocab_size
:
int
,
):
if
import_error
:
raise
import_error
self
.
executor
=
ThreadPoolExecutor
()
self
.
grammar_cache
=
XGrammarCache
(
tokenizer
,
vocab_size
)
self
.
vocab_size
=
vocab_size
def
_query
(
self
,
key
:
Tuple
[
str
,
str
])
->
XGrammarGrammar
:
return
XGrammarGrammar
(
self
.
grammar_cache
.
query
(
key
),
self
.
vocab_size
)
def
query
(
self
,
key
:
Tuple
[
str
,
str
])
->
Future
:
return
self
.
executor
.
submit
(
self
.
_query
,
key
)
def
reset
(
self
):
self
.
grammar_cache
.
reset
()
class
XGrammarCache
:
def
__init__
(
self
,
tokenizer
,
vocab_size
:
int
):
super
().
__init__
()
self
.
grammar_cache
=
CachedGrammarCompiler
(
tokenizer_or_vocab
=
tokenizer
)
self
.
vocab_size
=
vocab_size
def
get_context
(
self
,
key
:
Tuple
[
str
,
str
])
->
Compiled
Grammar
:
def
init_value_impl
(
self
,
key
:
Tuple
[
str
,
str
])
->
XGrammar
Grammar
:
key_type
,
key_string
=
key
if
key_type
==
"json"
:
return
self
.
grammar_cache
.
get_compiled_grammar_for_json_schema
(
key_string
)
ctx
=
self
.
grammar_cache
.
get_compiled_grammar_for_json_schema
(
key_string
)
elif
key_type
==
"regex"
:
raise
ValueError
(
"regex hasn't been supported by xgrammar yet"
)
else
:
raise
ValueError
(
f
"Invalid key_type:
{
key_type
}
"
)
def
query
(
self
,
key
:
Tuple
[
str
,
str
])
->
GrammarMatcher
:
ctx
=
self
.
get_context
(
key
)
return
GrammarMatcher
(
matcher
=
GrammarMatcher
(
ctx
,
max_rollback_tokens
=
MAX_ROLLBACK_TOKENS
,
mask_vocab_size
=
self
.
vocab_size
,
)
return
XGrammarGrammar
(
matcher
,
self
.
vocab_size
,
ctx
)
def
reset
(
self
):
self
.
grammar_cache
.
clear
()
python/sglang/srt/managers/schedule_batch.py
View file @
54479d6f
...
...
@@ -37,6 +37,7 @@ import torch
from
sglang.global_config
import
global_config
from
sglang.srt.configs.model_config
import
ModelConfig
from
sglang.srt.constrained.base_grammar_backend
import
BaseGrammarObject
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
from
sglang.srt.mem_cache.memory_pool
import
BaseTokenToKVPool
,
ReqToTokenPool
...
...
@@ -248,7 +249,7 @@ class Req:
self
.
embedding
=
None
# Constrained decoding
self
.
grammar
=
None
self
.
grammar
:
Optional
[
BaseGrammarObject
]
=
None
# The number of cached tokens, that were already cached in the KV cache
self
.
cached_tokens
=
0
...
...
python/sglang/srt/managers/scheduler.py
View file @
54479d6f
...
...
@@ -244,7 +244,7 @@ class Scheduler:
self
.
grammar_backend
=
OutlinesGrammarBackend
(
self
.
tokenizer
,
whitespace_pattern
s
=
server_args
.
constrained_json_whitespace_pattern
,
whitespace_pattern
=
server_args
.
constrained_json_whitespace_pattern
,
allow_jump_forward
=
not
server_args
.
disable_jump_forward
,
)
elif
server_args
.
grammar_backend
==
"xgrammar"
:
...
...
@@ -467,21 +467,6 @@ class Scheduler:
# By default, only return the logprobs for output tokens
req
.
logprob_start_len
=
len
(
recv_req
.
input_ids
)
-
1
# Init grammar cache for this request
if
(
req
.
sampling_params
.
json_schema
is
not
None
or
req
.
sampling_params
.
regex
is
not
None
):
assert
self
.
grammar_backend
is
not
None
if
req
.
sampling_params
.
json_schema
is
not
None
:
req
.
grammar
=
self
.
grammar_backend
.
query
(
(
"json"
,
req
.
sampling_params
.
json_schema
),
)
elif
req
.
sampling_params
.
regex
is
not
None
:
req
.
grammar
=
self
.
grammar_backend
.
query
(
(
"regex"
,
req
.
sampling_params
.
regex
)
)
# Truncate prompts that are too long
if
len
(
req
.
origin_input_ids
)
>
self
.
max_req_input_len
:
logger
.
warning
(
...
...
@@ -499,7 +484,24 @@ class Scheduler:
self
.
max_req_len
-
len
(
req
.
origin_input_ids
)
-
1
,
)
if
req
.
grammar
is
not
None
:
# Init grammar cache for this request
add_to_grammar_queue
=
False
if
(
req
.
sampling_params
.
json_schema
is
not
None
or
req
.
sampling_params
.
regex
is
not
None
):
assert
self
.
grammar_backend
is
not
None
if
req
.
sampling_params
.
json_schema
is
not
None
:
key
=
(
"json"
,
req
.
sampling_params
.
json_schema
)
elif
req
.
sampling_params
.
regex
is
not
None
:
key
=
(
"regex"
,
req
.
sampling_params
.
regex
)
req
.
grammar
=
self
.
grammar_backend
.
get_cached_value
(
key
)
if
not
req
.
grammar
:
req
.
grammar
=
self
.
grammar_backend
.
get_future_value
(
key
)
add_to_grammar_queue
=
True
if
add_to_grammar_queue
:
self
.
grammar_queue
.
append
(
req
)
else
:
self
.
waiting_queue
.
append
(
req
)
...
...
@@ -650,14 +652,7 @@ class Scheduler:
def
get_new_batch_prefill
(
self
)
->
Optional
[
ScheduleBatch
]:
# Check if the grammar is ready in the grammar queue
if
self
.
grammar_queue
:
new_grammar_queue
=
[]
for
req
in
self
.
grammar_queue
:
try
:
req
.
grammar
=
req
.
grammar
.
result
(
timeout
=
0.05
)
self
.
waiting_queue
.
append
(
req
)
except
futures
.
_base
.
TimeoutError
:
new_grammar_queue
.
append
(
req
)
self
.
grammar_queue
=
new_grammar_queue
self
.
move_ready_grammar_requests
()
# Handle the cases where prefill is not allowed
if
(
...
...
@@ -1145,6 +1140,30 @@ class Scheduler:
)
)
def
move_ready_grammar_requests
(
self
):
"""Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
num_ready_reqs
=
0
for
req
in
self
.
grammar_queue
:
try
:
req
.
grammar
=
req
.
grammar
.
result
(
timeout
=
0.05
)
num_ready_reqs
+=
1
except
futures
.
_base
.
TimeoutError
:
break
if
self
.
tp_size
>
1
:
# Sync across TP ranks to make sure they have the same number of ready requests
tensor
=
torch
.
tensor
(
num_ready_reqs
,
dtype
=
torch
.
int32
)
torch
.
distributed
.
all_reduce
(
tensor
,
op
=
torch
.
distributed
.
ReduceOp
.
MAX
,
group
=
self
.
tp_cpu_group
)
num_ready_reqs_max
=
tensor
.
item
()
for
i
in
range
(
num_ready_reqs
,
num_ready_reqs_max
):
self
.
grammar_queue
[
i
].
grammar
=
self
.
grammar_queue
[
i
].
grammar
.
result
()
num_ready_reqs
=
num_ready_reqs_max
self
.
waiting_queue
.
extend
(
self
.
grammar_queue
[:
num_ready_reqs
])
self
.
grammar_queue
=
self
.
grammar_queue
[
num_ready_reqs
:]
def
flush_cache
(
self
):
"""Flush the memory pool and cache."""
if
len
(
self
.
waiting_queue
)
==
0
and
(
...
...
@@ -1152,9 +1171,8 @@ class Scheduler:
):
self
.
tree_cache
.
reset
()
self
.
tree_cache_metrics
=
{
"total"
:
0
,
"hit"
:
0
}
if
self
.
grammar_backend
is
not
None
:
if
self
.
grammar_backend
:
self
.
grammar_backend
.
reset
()
# TODO(dark): reset the bnf cache
self
.
req_to_token_pool
.
clear
()
self
.
token_to_kv_pool
.
clear
()
torch
.
cuda
.
empty_cache
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment