GitLab · change / sglang · Commits · 9c745d07

[Performance] Update xgrammar-related constrained decoding (#2056)

Unverified commit 9c745d07, authored Nov 18, 2024 by DarkSharpness, committed via GitHub Nov 17, 2024.
Parent: ebaa2f31

Showing 4 changed files with 47 additions and 23 deletions (+47 −23):
- python/sglang/srt/constrained/outlines_backend.py (+11 −1)
- python/sglang/srt/constrained/xgrammar_backend.py (+22 −14)
- python/sglang/srt/model_executor/model_runner.py (+1 −1)
- python/sglang/srt/sampling/sampling_batch_info.py (+13 −7)
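Read together, the four diffs converge on a single mask protocol owned by the grammar object: allocate one mask tensor for the whole batch, fill one slot per request, and apply the mask to the logits through a static method, so the model runner no longer assumes a boolean mask layout. As a reading aid, here is a minimal sketch of that protocol as I infer it from the diff (my own distillation; the actual `BaseGrammarObject` base class is not part of this commit):

```python
import torch


class GrammarMaskProtocol:
    """Hypothetical distillation of the interface; method names follow the diff."""

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        # One tensor for the whole batch; the layout is backend-specific
        # (bool matrix for Outlines, packed bitset for xgrammar).
        raise NotImplementedError

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        # Write this request's constraint into slot `idx` of the batch mask.
        raise NotImplementedError

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        # Apply the mask to the logits in place; static so the runner can
        # call it without holding a grammar instance.
        raise NotImplementedError
```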
python/sglang/srt/constrained/outlines_backend.py

```diff
@@ -81,10 +81,20 @@ class OutlinesGrammar(BaseGrammarObject):
     ):
         self.state = next_state
 
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor):
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        vocab_mask = vocab_mask[idx]
         vocab_mask.fill_(1)
         vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
 
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
+        logits.masked_fill_(vocab_mask, float("-inf"))
+
     def copy(self):
         return OutlinesGrammar(self.guide, self.jump_forward_map)
 
```
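The Outlines mask stays a boolean tensor: `fill_vocab_mask` first bans everything in its row (`fill_(1)`), then clears the token ids the guide allows, and the static `apply_vocab_mask` pushes banned logits to `-inf`. A self-contained toy of those semantics (the token ids here are made up, not from a real guide):

```python
import torch

batch_size, vocab_size = 2, 8
mask = torch.zeros(batch_size, vocab_size, dtype=torch.bool)

allowed = [1, 3, 5]   # stand-in for guide.get_next_instruction(state).tokens
row = mask[0]         # fill one request's slot, as fill_vocab_mask(mask, 0) would
row.fill_(1)          # start with every token disallowed
row[allowed] = 0      # then clear the allowed token ids

logits = torch.randn(batch_size, vocab_size)
logits.masked_fill_(mask, float("-inf"))  # disallowed tokens become -inf

assert torch.isfinite(logits[0, allowed]).all()  # allowed tokens survive
```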
python/sglang/srt/constrained/xgrammar_backend.py

```diff
@@ -21,7 +21,12 @@ from typing import List, Tuple
 import torch
 
 try:
-    from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
+    from xgrammar import (
+        CachedGrammarCompiler,
+        CompiledGrammar,
+        GrammarMatcher,
+        TokenizerInfo,
+    )
 
     import_error = None
 except ImportError as e:
@@ -80,19 +85,23 @@ class XGrammarGrammar(BaseGrammarObject):
         for i in range(k, len(new_output_ids)):
             assert self.matcher.accept_token(new_output_ids[i])
 
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor):
-        # Note that this bitmask is a bitset, not bool
-        bitmask = self.matcher.get_next_token_bitmask()
-        # Mask the tokens that are not allowed
-        vocab_mask[
-            self.matcher.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size)
-        ] = 1
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return self.matcher.allocate_token_bitmask(vocab_size, batch_size)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        self.matcher.fill_next_token_bitmask(vocab_mask, idx)
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        GrammarMatcher.apply_token_bitmask_inplace(logits, vocab_mask)
 
     def copy(self):
         matcher = GrammarMatcher(
             self.ctx,
             max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            mask_vocab_size=self.vocab_size,
+            vocab_size=self.vocab_size,
         )
         return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
 
@@ -112,7 +121,8 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
             self.grammar_cache = None
             return
 
-        self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
+        tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)
+        self.grammar_cache = CachedGrammarCompiler(tokenizer_info=tokenizer_info)
         self.vocab_size = vocab_size
 
     def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
@@ -122,9 +132,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         key_type, key_string = key
         if key_type == "json":
            try:
-                ctx = self.grammar_cache.get_compiled_grammar_for_json_schema(
-                    key_string
-                )
+                ctx = self.grammar_cache.compile_json_schema_grammar(schema=key_string)
             except RuntimeError as e:
                 logging.warning(
                     f"Skip invalid json_schema: json_schema={key_string}, {e=}"
@@ -141,7 +149,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         matcher = GrammarMatcher(
             ctx,
             max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            mask_vocab_size=self.vocab_size,
+            vocab_size=self.vocab_size,
         )
         return XGrammarGrammar(matcher, self.vocab_size, ctx)
```
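The removed code's comment ("this bitmask is a bitset, not bool") explains why the xgrammar path now delegates allocation, filling, and application to the matcher itself: the mask is packed, roughly 32 vocabulary entries per machine word, rather than one bool per token, so generic tensor indexing would mishandle it. A rough pure-Python illustration of that packing idea (my own toy; xgrammar's actual layout and kernels may differ):

```python
# Toy packed bitset: 32 token ids per integer word.
vocab_size = 100
words = [0] * ((vocab_size + 31) // 32)

def set_allowed(token_id: int) -> None:
    # Set the bit for this token id in its 32-bit word.
    words[token_id // 32] |= 1 << (token_id % 32)

def is_allowed(token_id: int) -> bool:
    # Test the bit for this token id.
    return bool(words[token_id // 32] & (1 << (token_id % 32)))

set_allowed(42)
assert is_allowed(42) and not is_allowed(43)
```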
python/sglang/srt/model_executor/model_runner.py

```diff
@@ -645,7 +645,7 @@ class ModelRunner:
         # Apply regex vocab_mask
         if sampling_info.vocab_mask is not None:
-            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
+            sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask)
 
         return logits
```
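With this change the runner no longer hard-codes `masked_fill`; it invokes whatever callable the batch info carries, so xgrammar's packed bitmask and Outlines' boolean mask both work through the same line. A minimal sketch of that dispatch (all names below are stand-ins, not the real `SamplingBatchInfo`):

```python
import torch

def apply_bool_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
    # The Outlines-style strategy; an xgrammar-style callable would apply a
    # packed bitmask instead, and the caller would not need to know.
    logits.masked_fill_(vocab_mask, float("-inf"))

class FakeSamplingInfo:
    def __init__(self):
        self.vocab_mask = torch.zeros(1, 4, dtype=torch.bool)
        self.vocab_mask[0, 2] = True          # ban token 2 for request 0
        self.apply_mask = apply_bool_mask     # strategy chosen by the backend

info = FakeSamplingInfo()
logits = torch.randn(1, 4)
if info.vocab_mask is not None:
    info.apply_mask(logits=logits, vocab_mask=info.vocab_mask)
assert logits[0, 2] == float("-inf")
```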
python/sglang/srt/sampling/sampling_batch_info.py

```diff
 from __future__ import annotations
 
 import dataclasses
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional
 
 import torch
@@ -29,7 +29,8 @@ class SamplingBatchInfo:
     vocab_size: int
     logit_bias: torch.Tensor = None
     vocab_mask: Optional[torch.Tensor] = None
+    apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
 
     grammars: Optional[List] = None
 
     # Penalizer
@@ -135,17 +135,23 @@ class SamplingBatchInfo:
     def update_regex_vocab_mask(self):
         if not self.grammars or not any(grammar for grammar in self.grammars):
             self.vocab_mask = None
+            self.apply_mask = None
             return
 
-        self.vocab_mask = torch.zeros(
-            len(self.temperatures),
-            self.vocab_size,
-            dtype=torch.bool,
-            device=self.device,
-        )
+        # find a grammar from the list
+        grammar = next(grammar for grammar in self.grammars if grammar is not None)
+
+        # maybe we can reuse the existing mask?
+        self.vocab_mask = grammar.allocate_vocab_mask(
+            vocab_size=self.vocab_size,
+            batch_size=len(self.temperatures),
+            device=self.device,
+        )
+        self.apply_mask = type(grammar).apply_vocab_mask  # force to use static method
+
         for i, grammar in enumerate(self.grammars):
             if grammar is not None:
-                grammar.fill_vocab_mask(self.vocab_mask[i])
+                grammar.fill_vocab_mask(self.vocab_mask, i)
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         if self.penalizer_orchestrator:
```
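`update_regex_vocab_mask` now picks the first live grammar to allocate the whole-batch mask, and stores `type(grammar).apply_vocab_mask` so the plain static function is saved rather than a bound method. An end-to-end toy of the resulting allocate/fill/apply flow, using a boolean-mask grammar in the Outlines style (all names local to this sketch):

```python
import torch

class BoolGrammar:
    """Stand-in grammar implementing the three-method mask protocol."""

    def __init__(self, allowed):
        self.allowed = allowed  # hypothetical allowed token ids

    def allocate_vocab_mask(self, vocab_size, batch_size, device):
        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

    def fill_vocab_mask(self, vocab_mask, idx):
        row = vocab_mask[idx]
        row.fill_(1)            # ban everything in this request's row
        row[self.allowed] = 0   # then clear the allowed ids

    @staticmethod
    def apply_vocab_mask(logits, vocab_mask):
        logits.masked_fill_(vocab_mask, float("-inf"))

grammars = [BoolGrammar([0, 1]), None, BoolGrammar([2])]  # None = unconstrained

# Allocate once via any live grammar, fill per request, apply once per batch.
grammar = next(g for g in grammars if g is not None)
mask = grammar.allocate_vocab_mask(vocab_size=4, batch_size=3, device="cpu")
for i, g in enumerate(grammars):
    if g is not None:
        g.fill_vocab_mask(mask, i)

apply_mask = type(grammar).apply_vocab_mask  # plain function, no bound instance
logits = torch.randn(3, 4)
apply_mask(logits=logits, vocab_mask=mask)   # row 1 stays unconstrained
```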