Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
19120f71
Unverified
Commit
19120f71
authored
Mar 04, 2025
by
DarkSharpness
Committed by
GitHub
Mar 04, 2025
Browse files
[Fix & Style] Refactor the grammar backend to reduce human errors and improve readability (#4030)
parent
2415ec38
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
199 additions
and
95 deletions
+199
-95
python/sglang/srt/constrained/base_grammar_backend.py
python/sglang/srt/constrained/base_grammar_backend.py
+110
-14
python/sglang/srt/constrained/llguidance_backend.py
python/sglang/srt/constrained/llguidance_backend.py
+21
-16
python/sglang/srt/constrained/outlines_backend.py
python/sglang/srt/constrained/outlines_backend.py
+20
-18
python/sglang/srt/constrained/xgrammar_backend.py
python/sglang/srt/constrained/xgrammar_backend.py
+48
-47
No files found.
python/sglang/srt/constrained/base_grammar_backend.py
View file @
19120f71
...
...
@@ -13,31 +13,130 @@
# ==============================================================================
"""The baseclass of a backend for grammar-guided constrained decoding."""
import
logging
from
abc
import
ABC
,
abstractmethod
from
concurrent.futures
import
Future
,
ThreadPoolExecutor
from
dataclasses
import
dataclass
from
threading
import
Event
,
Lock
from
typing
import
Any
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
from
sglang.srt.server_args
import
ServerArgs
logger
=
logging
.
getLogger
(
__name__
)
class BaseGrammarObject(ABC):
    """Abstract interface implemented by every backend-specific grammar object.

    Every method below is abstract, so a concrete backend has to override all
    of them before the class can be instantiated.
    """

    @abstractmethod
    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        """
        Try to jump forward in the grammar.

        Returns:
            A jump forward helper which may be used in `jump_forward_str_state`.
            None if the jump forward is not possible.
        """
        raise NotImplementedError()

    @abstractmethod
    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        """
        Jump forward for the grammar.

        Returns:
            A tuple of the jump forward string and the next state of the grammar
            (which can be used in `jump_and_retokenize` if needed).
        """
        raise NotImplementedError()

    @abstractmethod
    def jump_and_retokenize(
        self,
        old_output_ids: List[int],
        new_output_ids: List[int],
        next_state: int,
    ) -> None:
        """
        Jump forward occurs, and update the grammar state if needed.
        """
        raise NotImplementedError()

    @abstractmethod
    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        """
        Create and return the vocabulary mask tensor.

        NOTE(review): the mask layout (shape, dtype) is backend-defined and not
        visible here -- confirm against the concrete implementations.
        """
        raise NotImplementedError()

    @abstractmethod
    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        """
        Update `vocab_mask` in place for the entry selected by `idx`.

        NOTE(review): presumably `idx` is the request's row in the batch -- confirm.
        """
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        """
        Return `vocab_mask` placed on `device` (static: needs no instance state).
        """
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        """
        Apply `vocab_mask` to `logits`; returns nothing, so the update is in place.
        """
        raise NotImplementedError()

    @abstractmethod
    def copy(self) -> "BaseGrammarObject":
        """
        Return a new grammar object equivalent to this one.
        """
        raise NotImplementedError()
@dataclass
class CacheEntry:
    """One slot of the grammar cache kept by the backend.

    The stale duplicate `value: Any` annotation was dropped: it was silently
    overridden by the later, more precise annotation of the same field.
    """

    # Grammar object built for the key, or None when no grammar could be built.
    value: Optional["BaseGrammarObject"]
    # Set once `value` has been filled in; cache-hit callers wait on it.
    event: Event
class
BaseGrammarObject
:
pass
class
BaseGrammarBackend
:
class
BaseGrammarBackend
(
ABC
):
def __init__(self):
    """Create the worker pool, the empty grammar cache and the cache lock.

    The earlier untyped `self.cache = {}` assignment was a dead store that was
    immediately overwritten by the annotated one, so only the latter is kept.
    """
    # Pool used to build grammar objects in the background.
    self.executor = ThreadPoolExecutor()
    # Maps a (key_type, key_string) pair to its cache slot.
    self.cache: Dict[Tuple[str, str], CacheEntry] = {}
    # Held while `self.cache` is read or modified.
    self.cache_lock = Lock()
def
init_value
(
self
,
key
:
Tuple
[
str
,
str
])
->
BaseGrammarObject
:
def _not_supported(self, key_type: str, key_string: str) -> None:
    """Log a warning that this backend does not handle the given grammar kind.

    Args:
        key_type: Grammar kind being requested (e.g. "json", "regex").
        key_string: The grammar text supplied for that kind.

    Returns:
        None, which the `dispatch_*` defaults hand back to their caller.
    """
    message = f"Skip unsupported {key_type}: {key_type}={key_string}"
    logger.warning(message)
def dispatch_fallback(
    self, key_type: str, key_string: str
) -> Optional[BaseGrammarObject]:
    """
    This function should not be reached in any case.

    It is the last branch of `_init_value_dispatch`, taken only when
    `key_type` matches none of the known grammar kinds.

    Raises:
        ValueError: always, naming the offending key type and key string.
    """
    detail = f"Invalid key_type: {key_type}={key_string}"
    raise ValueError(detail)
@abstractmethod
def dispatch_json(self, key_string: str) -> Optional[BaseGrammarObject]:
    """Build the grammar object for a "json" key.

    Declared abstract, so every backend must override it. This default body
    only reports the kind as unsupported and returns the result of
    `_not_supported` (None); an override may reach it through `super()`.
    """
    unsupported_kind = "json"
    return self._not_supported(unsupported_kind, key_string)
@abstractmethod
def dispatch_regex(self, key_string: str) -> Optional[BaseGrammarObject]:
    """Build the grammar object for a "regex" key.

    Declared abstract, so every backend must override it. This default body
    only reports the kind as unsupported and returns the result of
    `_not_supported` (None); an override may reach it through `super()`.
    """
    unsupported_kind = "regex"
    return self._not_supported(unsupported_kind, key_string)
@abstractmethod
def dispatch_ebnf(self, key_string: str) -> Optional[BaseGrammarObject]:
    """Build the grammar object for an "ebnf" key.

    Declared abstract, so every backend must override it. This default body
    only reports the kind as unsupported and returns the result of
    `_not_supported` (None); an override may reach it through `super()`.
    """
    unsupported_kind = "ebnf"
    return self._not_supported(unsupported_kind, key_string)
@abstractmethod
def dispatch_structural_tag(self, key_string: str) -> Optional[BaseGrammarObject]:
    """Build the grammar object for a "structural_tag" key.

    Declared abstract, so every backend must override it. This default body
    only reports the kind as unsupported and returns the result of
    `_not_supported` (None); an override may reach it through `super()`.
    """
    unsupported_kind = "structural_tag"
    return self._not_supported(unsupported_kind, key_string)
def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
    """Route a cache key to the `dispatch_*` method matching its grammar kind.

    Args:
        key: A `(key_type, key_string)` pair.

    Returns:
        Whatever the selected `dispatch_*` method returns. An unrecognised
        `key_type` is handed to `dispatch_fallback`.
    """
    key_type, key_string = key
    # (grammar kind, handler method name), checked in this order.
    routes = (
        ("json", "dispatch_json"),
        ("regex", "dispatch_regex"),
        ("ebnf", "dispatch_ebnf"),
        ("structural_tag", "dispatch_structural_tag"),
    )
    for kind, handler_name in routes:
        if key_type == kind:
            handler = getattr(self, handler_name)
            return handler(key_string)
    return self.dispatch_fallback(key_type, key_string)
def
_init_value
(
self
,
key
:
Tuple
[
str
,
str
])
->
Optional
[
BaseGrammarObject
]:
with
self
.
cache_lock
:
if
key
in
self
.
cache
:
cache_hit
=
True
...
...
@@ -50,13 +149,10 @@ class BaseGrammarBackend:
if
cache_hit
:
entry
.
event
.
wait
()
else
:
entry
.
value
=
self
.
init_value_
impl
(
key
)
entry
.
value
=
self
.
_
init_value_
dispatch
(
key
)
entry
.
event
.
set
()
return
entry
.
value
.
copy
()
if
entry
.
value
else
None
def
init_value_impl
(
self
,
key
:
Tuple
[
str
,
str
])
->
BaseGrammarObject
:
raise
NotImplementedError
()
def
get_cached_value
(
self
,
key
:
Tuple
[
str
,
str
])
->
Optional
[
BaseGrammarObject
]:
with
self
.
cache_lock
:
entry
=
self
.
cache
.
get
(
key
)
...
...
@@ -66,7 +162,7 @@ class BaseGrammarBackend:
return
val
.
copy
()
if
val
else
None
def get_future_value(self, key: Tuple[str, str]) -> Future:
    """Schedule construction of the grammar for `key` on the executor.

    Args:
        key: A `(key_type, key_string)` pair identifying the grammar.

    Returns:
        The Future returned by `self.executor.submit` for `self._init_value(key)`.
    """
    # The stale `submit(self.init_value, key)` line was removed: it returned
    # first and made the `_init_value` submission unreachable, although the
    # worker method is now named `_init_value`.
    return self.executor.submit(self._init_value, key)
def
reset
(
self
):
with
self
.
cache_lock
:
...
...
python/sglang/srt/constrained/llguidance_backend.py
View file @
19120f71
...
...
@@ -48,7 +48,7 @@ class GuidanceGrammar(BaseGrammarObject):
self
.
finished
=
False
self
.
bitmask
=
None
def
try_jump_forward
(
self
,
tokenizer
)
->
Tuple
[
List
[
int
],
str
]:
def
try_jump_forward
(
self
,
tokenizer
)
->
Optional
[
Tuple
[
List
[
int
],
str
]
]
:
if
len
(
self
.
pending_ff_tokens
)
>
0
:
s
=
self
.
llguidance_tokenizer
.
decode_str
(
self
.
pending_ff_tokens
)
ff_tokens
=
self
.
pending_ff_tokens
...
...
@@ -125,22 +125,27 @@ class GuidanceBackend(BaseGrammarBackend):
)
self
.
llguidance_tokenizer
=
llguidance
.
hf
.
from_tokenizer
(
self
.
tokenizer
,
None
)
def
init_value_impl
(
self
,
key
:
Tuple
[
str
,
str
])
->
GuidanceGrammar
:
mode
,
value
=
key
if
mode
==
"json"
:
json_schema
=
value
compiler
=
llguidance
.
JsonCompiler
(
whitespace_flexible
=
self
.
whitespace_flexible
)
serialized_grammar
=
compiler
.
compile
(
json_schema
)
elif
mode
==
"regex"
:
compiler
=
llguidance
.
RegexCompiler
()
serialized_grammar
=
compiler
.
compile
(
regex
=
value
)
elif
mode
==
"ebnf"
:
compiler
=
llguidance
.
LarkCompiler
()
serialized_grammar
=
compiler
.
compile
(
any_to_lark
(
value
))
def _from_serialized(self, serialized_grammar) -> GuidanceGrammar:
    """Wrap a compiled grammar in a `GuidanceGrammar` using this backend's tokenizer.

    Args:
        serialized_grammar: Output of one of the llguidance compilers.
    """
    grammar_kwargs = {
        "llguidance_tokenizer": self.llguidance_tokenizer,
        "serialized_grammar": serialized_grammar,
    }
    return GuidanceGrammar(**grammar_kwargs)
def dispatch_json(self, key_string: str) -> GuidanceGrammar:
    """Compile a JSON schema string with llguidance and wrap the result.

    The compiler is created with this backend's `whitespace_flexible` setting.
    """
    json_compiler = llguidance.JsonCompiler(
        whitespace_flexible=self.whitespace_flexible
    )
    return self._from_serialized(json_compiler.compile(key_string))
def dispatch_regex(self, key_string: str) -> GuidanceGrammar:
    """Compile a regular expression with llguidance and wrap the result."""
    regex_compiler = llguidance.RegexCompiler()
    return self._from_serialized(regex_compiler.compile(regex=key_string))
def dispatch_ebnf(self, key_string: str) -> GuidanceGrammar:
    """Compile an EBNF grammar with llguidance and wrap the result.

    The grammar text is passed through `any_to_lark` before compilation.
    """
    # Compiler first, conversion second: same evaluation order as before.
    lark_compiler = llguidance.LarkCompiler()
    return self._from_serialized(lark_compiler.compile(any_to_lark(key_string)))
def dispatch_structural_tag(self, key_string: str):
    """Handle a "structural_tag" key by deferring to the parent class.

    This backend adds no handling of its own for this grammar kind.
    """
    inherited_result = super().dispatch_structural_tag(key_string)
    return inherited_result
python/sglang/srt/constrained/outlines_backend.py
View file @
19120f71
...
...
@@ -141,24 +141,7 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
)
self
.
whitespace_pattern
=
whitespace_pattern
def
init_value_impl
(
self
,
key
:
Tuple
[
str
,
str
])
->
OutlinesGrammar
:
key_type
,
key_string
=
key
if
key_type
==
"json"
:
try
:
regex
=
build_regex_from_object
(
key_string
,
whitespace_pattern
=
self
.
whitespace_pattern
,
)
except
(
NotImplementedError
,
json
.
decoder
.
JSONDecodeError
)
as
e
:
logger
.
warning
(
f
"Skip invalid json_schema: json_schema=
{
key_string
}
,
{
e
=
}
"
)
return
None
elif
key_type
==
"regex"
:
regex
=
key_string
else
:
raise
ValueError
(
f
"Invalid key_type:
{
key_type
}
"
)
def
_compile_regex
(
self
,
regex
:
str
)
->
Optional
[
OutlinesGrammar
]:
try
:
if
hasattr
(
RegexGuide
,
"from_regex"
):
# outlines >= 0.1.1
...
...
@@ -173,6 +156,25 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
jump_forward_map
=
None
return
OutlinesGrammar
(
guide
,
jump_forward_map
)
def dispatch_ebnf(self, key_string: str):
    """Handle an "ebnf" key by deferring to the parent class.

    This backend adds no handling of its own for this grammar kind.
    """
    inherited_result = super().dispatch_ebnf(key_string)
    return inherited_result
def dispatch_structural_tag(self, key_string: str):
    """Handle a "structural_tag" key by deferring to the parent class.

    This backend adds no handling of its own for this grammar kind.
    """
    inherited_result = super().dispatch_structural_tag(key_string)
    return inherited_result
def dispatch_json(self, key_string: str):
    """Build a grammar for a JSON schema by first converting it to a regex.

    Args:
        key_string: The JSON schema text.

    Returns:
        The result of `_compile_regex` for the derived regex, or None when the
        schema cannot be converted (a warning is logged in that case).
    """
    try:
        regex = build_regex_from_object(
            key_string,
            whitespace_pattern=self.whitespace_pattern,
        )
    except (NotImplementedError, json.decoder.JSONDecodeError) as e:
        logger.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
        # Must bail out here: `regex` was never bound, so falling through to
        # `_compile_regex(regex)` would raise UnboundLocalError instead of
        # skipping the invalid schema.
        return None
    return self._compile_regex(regex)
def dispatch_regex(self, key_string: str):
    """Build a grammar for a "regex" key: the key string is the regex itself."""
    pattern = key_string
    return self._compile_regex(pattern)
def
build_regex_from_object
(
object
:
Union
[
str
,
BaseModel
,
Dict
],
whitespace_pattern
:
Optional
[
str
]
=
None
...
...
python/sglang/srt/constrained/xgrammar_backend.py
View file @
19120f71
...
...
@@ -57,7 +57,7 @@ class XGrammarGrammar(BaseGrammarObject):
def accept_token(self, token: int):
    """Feed `token` to the matcher and fail if the matcher rejects it.

    Args:
        token: Token id passed to `self.matcher.accept_token`.

    Raises:
        AssertionError: if the matcher returns a falsy result for the token.
    """
    # The call must not live inside an `assert`: under `python -O` assert
    # statements are removed, which would skip the matcher call altogether.
    accepted = self.matcher.accept_token(token)
    if not accepted:
        raise AssertionError(f"Grammar matcher rejected token {token}")
def
try_jump_forward
(
self
,
tokenizer
)
->
Tuple
[
List
[
int
],
str
]:
def
try_jump_forward
(
self
,
tokenizer
)
->
Optional
[
Tuple
[
List
[
int
],
str
]
]
:
s
=
self
.
matcher
.
find_jump_forward_string
()
if
s
:
return
[],
s
...
...
@@ -128,55 +128,56 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
self
.
vocab_size
=
vocab_size
self
.
override_stop_tokens
=
override_stop_tokens
def
init_value_impl
(
self
,
key
:
Tuple
[
str
,
str
])
->
XGrammarGrammar
:
key_type
,
key_string
=
key
if
key_type
==
"json"
:
try
:
if
key_string
==
"$$ANY$$"
:
ctx
=
self
.
grammar_compiler
.
compile_builtin_json_grammar
()
else
:
ctx
=
self
.
grammar_compiler
.
compile_json_schema
(
schema
=
key_string
)
except
RuntimeError
as
e
:
logging
.
warning
(
f
"Skip invalid json_schema: json_schema=
{
key_string
}
,
{
e
=
}
"
)
return
None
elif
key_type
==
"ebnf"
:
try
:
ctx
=
self
.
grammar_compiler
.
compile_grammar
(
key_string
)
except
RuntimeError
as
e
:
logging
.
warning
(
f
"Skip invalid ebnf: ebnf=
{
key_string
}
,
{
e
=
}
"
)
return
None
elif
key_type
==
"regex"
:
try
:
ctx
=
self
.
grammar_compiler
.
compile_regex
(
key_string
)
except
RuntimeError
as
e
:
logging
.
warning
(
f
"Skip invalid regex: regex=
{
key_string
}
,
{
e
=
}
"
)
return
None
elif
key_type
==
"structural_tag"
:
try
:
structural_tag
=
json
.
loads
(
key_string
)
tags
=
[
StructuralTagItem
(
begin
=
structure
[
"begin"
],
schema
=
json
.
dumps
(
structure
[
"schema"
]),
end
=
structure
[
"end"
],
)
for
structure
in
structural_tag
[
"structures"
]
]
ctx
=
self
.
grammar_compiler
.
compile_structural_tag
(
tags
,
structural_tag
[
"triggers"
]
)
except
RuntimeError
as
e
:
logging
.
warning
(
f
"Skip invalid regex: regex=
{
key_string
}
,
{
e
=
}
"
)
return
None
else
:
raise
ValueError
(
f
"Invalid key_type:
{
key_type
}
"
)
def _from_context(self, ctx: CompiledGrammar) -> XGrammarGrammar:
    """Create an `XGrammarGrammar` for an already compiled grammar.

    A new `GrammarMatcher` is built over `ctx` (rollback limited to
    `MAX_ROLLBACK_TOKENS`) and bundled with this backend's vocabulary size
    and stop-token override.
    """
    grammar_matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
    grammar_args = (
        grammar_matcher,
        self.vocab_size,
        ctx,
        self.override_stop_tokens,
    )
    return XGrammarGrammar(*grammar_args)
def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
    """Compile a JSON schema into a grammar object.

    The sentinel schema "$$ANY$$" selects the compiler's built-in JSON
    grammar; any other value is compiled as a schema.

    Returns:
        The grammar object, or None (after a warning) when the compiler
        raises RuntimeError.
    """
    try:
        use_builtin = key_string == "$$ANY$$"
        ctx = (
            self.grammar_compiler.compile_builtin_json_grammar()
            if use_builtin
            else self.grammar_compiler.compile_json_schema(schema=key_string)
        )
    except RuntimeError as e:
        logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
        return None
    else:
        return self._from_context(ctx)
def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
    """Compile an EBNF grammar string into a grammar object.

    Returns:
        The grammar object, or None (after a warning) when the compiler
        raises RuntimeError.
    """
    try:
        compiled = self.grammar_compiler.compile_grammar(key_string)
    except RuntimeError as e:
        logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
        return None
    else:
        return self._from_context(compiled)
def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
    """Compile a regular expression into a grammar object.

    Returns:
        The grammar object, or None (after a warning) when the compiler
        raises RuntimeError.
    """
    try:
        compiled = self.grammar_compiler.compile_regex(key_string)
    except RuntimeError as e:
        logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
        return None
    else:
        return self._from_context(compiled)
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
    """Compile a structural-tag specification into a grammar object.

    Args:
        key_string: JSON text with a "structures" list (each item providing
            "begin", "schema" and "end") and a "triggers" entry.

    Returns:
        The grammar object, or None (after a warning) when the JSON text
        cannot be decoded or the compiler raises RuntimeError.
    """
    try:
        structural_tag = json.loads(key_string)
        tags = [
            StructuralTagItem(
                begin=structure["begin"],
                # The compiler takes each schema as JSON text, not as a dict.
                schema=json.dumps(structure["schema"]),
                end=structure["end"],
            )
            for structure in structural_tag["structures"]
        ]
        ctx = self.grammar_compiler.compile_structural_tag(
            tags, structural_tag["triggers"]
        )
    except (RuntimeError, json.decoder.JSONDecodeError) as e:
        # The message used to say "regex" (copy-paste slip). Undecodable JSON
        # is now skipped like any other invalid grammar instead of escaping.
        logging.warning(
            f"Skip invalid structural_tag: structural_tag={key_string}, {e=}"
        )
        return None
    return self._from_context(ctx)
def reset(self):
    """Clear the grammar compiler's cache, if a compiler is present."""
    compiler = self.grammar_compiler
    if not compiler:
        # Nothing to clear when no compiler is set.
        return
    compiler.clear_cache()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment