sglang — Commit bca832c7 (unverified)

[Fix] fix outlines and xgrammar (#4947)

Authored by JieXin Liang on Apr 21, 2025; committed via GitHub on Apr 20, 2025.
Parent commit: d9dd5298

Showing 5 changed files with 185 additions and 10 deletions:

  python/sglang/srt/constrained/outlines_jump_forward.py     +14   -1
  python/sglang/srt/constrained/triton_ops/bitmask_ops.py    +141  -0
  python/sglang/srt/constrained/xgrammar_backend.py          +26   -4
  python/sglang/srt/model_executor/model_runner.py           +0    -5
  python/sglang/srt/server_args.py                           +4    -0
python/sglang/srt/constrained/outlines_jump_forward.py

@@ -19,10 +19,13 @@ Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
import dataclasses
import logging
from collections import defaultdict
from typing import Optional

import interegular
from interegular import InvalidSyntax
from outlines.caching import cache as disk_cache
from outlines.caching import cache

from sglang.srt.utils import get_bool_env_var

try:
    # outlines >= 0.1.0

@@ -34,6 +37,9 @@ except ImportError:
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

# Env var was set in sglang.srt.server_args.ServerArgs.__post__init__
DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")

logger = logging.getLogger(__name__)

@@ -45,6 +51,13 @@ class JumpEdge:
    byte_next_state: int = None


def disk_cache(expire: Optional[float] = None, typed=False, ignore=()):
    if not DISABLE_DISK_CACHE:
        return cache(expire, typed, ignore)
    else:
        return lambda fn: None


@disk_cache()
def init_state_to_jump_forward(regex_string):
    try:
python/sglang/srt/constrained/triton_ops/bitmask_ops.py  (new file, mode 100644)

# Adapt from
# https://github.com/mlc-ai/xgrammar/blob/v0.1.17/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
from typing import List, Optional, Union

import torch
import triton
import triton.language as tl

from sglang.srt.utils import get_device_core_count


@triton.jit
def apply_token_bitmask_inplace_kernel(
    logits_ptr,
    bitmask_ptr,
    indices_ptr,
    num_rows,
    vocab_size,
    logits_strides,
    bitmask_strides,
    NUM_SMS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Apply a bitmask to logits in-place using Triton. The bitmask is a 01 bitwise compressed tensor,
    where 0 means the token is masked and 1 means the token is not masked. After applying the bitmask,
    the masked logits will be set to -inf.

    Parameters
    ----------
    logits_ptr : tl.tensor
        Pointer to the logits tensor to apply the bitmask to.
    bitmask_ptr : tl.tensor
        Pointer to the bitmask tensor to apply.
    indices_ptr : Optional[tl.tensor]
        Optional pointer to indices tensor specifying which rows to apply the mask to.
    num_rows : int
        Number of rows to process. If indices_ptr is provided, this is the number of unique indices.
    vocab_size : int
        Size of the vocabulary dimension. If the logits does not have a vocab padding, this is the
        same as the logits's second dimension. Otherwise, this is the actual size of the vocabulary.
    logits_strides : int
        Stride between rows in the logits tensor.
    bitmask_strides : int
        Stride between rows in the bitmask tensor.
    NUM_SMS : int
        Number of streaming multiprocessors to use.
    BLOCK_SIZE : int
        Size of processing blocks.
    """
    pid = tl.program_id(0)
    num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)
    for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
        row_id = work_id // num_blocks
        block_offset = (work_id % num_blocks) * BLOCK_SIZE
        batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
        offsets = block_offset + tl.arange(0, BLOCK_SIZE)
        bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)

        vocab_mask = offsets < vocab_size
        packed_bitmask_mask = bitmask_offsets < bitmask_strides
        packed_bitmask = tl.load(
            bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets,
            packed_bitmask_mask,
        )
        bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
        bitmask = bitmask.reshape(BLOCK_SIZE)

        tl.store(
            logits_ptr + batch_id * logits_strides + offsets,
            -float("inf"),
            vocab_mask & bitmask,
        )


def apply_token_bitmask_inplace_triton(
    logits: torch.Tensor,
    bitmask: torch.Tensor,
    indices: Optional[Union[List[int], torch.Tensor]] = None,
):
    NUM_SMS = get_device_core_count()
    BLOCK_SIZE = 4096
    BITS_PER_BLOCK = 32

    # Check input dtype
    assert bitmask.dtype == torch.int32, "bitmask must be of type int32"

    # Check input tensor shapes.
    logits_shape = logits.shape
    bitmask_shape = bitmask.shape
    if logits.ndim == 1:
        logits_shape = (1, logits_shape[0])
    if bitmask.ndim == 1:
        bitmask_shape = (1, bitmask_shape[0])

    required_bitmask_width = (logits_shape[1] + BITS_PER_BLOCK - 1) // BITS_PER_BLOCK
    assert required_bitmask_width >= bitmask_shape[1], (
        f"Bitmask width too large: allow at most {required_bitmask_width} int32s for "
        f"logits' width {logits_shape[1]}, but got {bitmask_shape[1]}"
    )

    vocab_size = min(logits_shape[1], bitmask_shape[1] * BITS_PER_BLOCK)

    num_rows = None
    if isinstance(indices, list) or isinstance(indices, torch.Tensor):
        indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
        num_rows = indices.shape[0]
    else:
        assert (
            logits_shape[0] == bitmask_shape[0]
        ), f"batch size mismatch: logits {logits_shape[0]} vs bitmask {bitmask_shape[0]}"
        num_rows = logits_shape[0]

    if NUM_SMS > 0:
        grid = (NUM_SMS,)
    else:
        num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
        grid = (num_rows * num_blocks,)
        NUM_SMS = triton.next_power_of_2(grid[0])

    apply_token_bitmask_inplace_kernel[grid](
        logits,
        bitmask,
        indices,
        num_rows,
        vocab_size,
        logits_shape[1],
        bitmask_shape[1],
        NUM_SMS,
        BLOCK_SIZE,
        num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
        num_stages=3,
    )
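For orientation, here is a hedged usage sketch of apply_token_bitmask_inplace_triton as added above. It assumes a CUDA-capable GPU with Triton available; the vocabulary size and the allowed token ids are made-up illustrative values, and the packing follows the kernel's convention (bit 1 keeps a token, bit 0 masks its logit to -inf).

import torch

from sglang.srt.constrained.triton_ops.bitmask_ops import apply_token_bitmask_inplace_triton

vocab_size = 64
allowed = [3, 7]  # token ids the grammar still permits (illustrative values)

# Pack the allowed set into an int32 bitmask: bit == 1 keeps the token, bit == 0 masks it.
words = torch.zeros((vocab_size + 31) // 32, dtype=torch.int32)
for tok in allowed:
    words[tok // 32] |= 1 << (tok % 32)

logits = torch.randn(1, vocab_size, device="cuda")
bitmask = words.unsqueeze(0).to("cuda")

apply_token_bitmask_inplace_triton(logits, bitmask)

assert torch.isfinite(logits[0, allowed]).all()  # permitted tokens keep their logits
assert torch.isinf(logits[0, :3]).all()          # masked tokens are now -inf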
python/sglang/srt/constrained/xgrammar_backend.py

@@ -25,13 +25,16 @@ from xgrammar import (
    StructuralTagItem,
    TokenizerInfo,
    allocate_token_bitmask,
    apply_token_bitmask_inplace,
)

from sglang.srt.constrained.base_grammar_backend import (
    BaseGrammarBackend,
    BaseGrammarObject,
)
from sglang.srt.constrained.triton_ops.bitmask_ops import (
    apply_token_bitmask_inplace_triton,
)
from sglang.srt.utils import get_bool_env_var

logger = logging.getLogger(__name__)

@@ -55,6 +58,18 @@ class XGrammarGrammar(BaseGrammarObject):
        self.override_stop_tokens = override_stop_tokens
        self.finished = False

        # Fix (from vLLM team): postpone the import of apply_token_bitmask_inplace_kernels to the
        # class init site to avoid re-initializing CUDA in forked subprocess.
        from xgrammar.kernels import apply_token_bitmask_inplace_kernels

        self.use_token_bitmask_triton = get_bool_env_var(
            "SGLANG_TOKEN_BITMASK_TRITON", "false"
        )
        self.apply_vocab_mask_cuda = apply_token_bitmask_inplace_kernels.get("cuda", None)
        self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_kernels.get("cpu", None)

    def accept_token(self, token: int):
        assert self.matcher.accept_token(token)

@@ -97,9 +112,16 @@ class XGrammarGrammar(BaseGrammarObject):
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        return vocab_mask.to(device, non_blocking=True)

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        apply_token_bitmask_inplace(logits, vocab_mask)

    def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        if (
            not self.use_token_bitmask_triton
            and logits.device.type == "cuda"
            and self.apply_vocab_mask_cuda
        ):
            return self.apply_vocab_mask_cuda(logits, vocab_mask)
        if logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
            return self.apply_vocab_mask_cpu(logits, vocab_mask)
        apply_token_bitmask_inplace_triton(logits, vocab_mask)

    def copy(self):
        matcher = GrammarMatcher(
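In the dispatch above, the instance-level apply_vocab_mask prefers xgrammar's bundled CUDA kernel on GPU tensors, falls back to its CPU kernel on CPU tensors, and otherwise calls the Triton kernel added in bitmask_ops.py. Setting SGLANG_TOKEN_BITMASK_TRITON (read with a default of "false") skips the CUDA kernel and forces the Triton path. A minimal sketch, assuming the variable is exported before any XGrammarGrammar is constructed:

import os

# Must be set before XGrammarGrammar.__init__ runs, since the flag is read there once.
os.environ["SGLANG_TOKEN_BITMASK_TRITON"] = "true"

# From this point on, apply_vocab_mask on CUDA logits bypasses xgrammar's CUDA kernel and
# calls apply_token_bitmask_inplace_triton from bitmask_ops.py instead.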
python/sglang/srt/model_executor/model_runner.py

@@ -137,11 +137,6 @@ class ModelRunner:
        if server_args.show_time_cost:
            enable_show_time_cost()

        if server_args.disable_outlines_disk_cache:
            from outlines.caching import disable_cache

            disable_cache()

        # Global vars
        global_server_args_dict.update(
            {
python/sglang/srt/server_args.py

@@ -392,6 +392,10 @@ class ServerArgs:
        os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
            "1" if self.enable_torch_compile else "0"
        )

        # Set env var before grammar backends init
        os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = (
            "1" if self.disable_outlines_disk_cache else "0"
        )

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
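Taken together with the deletion in model_runner.py, the disk-cache toggle now travels through an environment variable: ServerArgs.__post_init__ publishes SGLANG_DISABLE_OUTLINES_DISK_CACHE before the grammar backends are imported, and outlines_jump_forward.py reads it at module import time (defaulting to "true", i.e. cache disabled). A minimal sketch of that hand-off; the simplified reader below is an assumption for illustration, not sglang's get_bool_env_var:

import os

def read_bool_env(name: str, default: str) -> bool:
    # Simplified stand-in for sglang.srt.utils.get_bool_env_var (assumed parsing rules).
    return os.environ.get(name, default).lower() in ("1", "true")

# 1) ServerArgs.__post_init__ publishes the setting first ...
disable_outlines_disk_cache = True
os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = "1" if disable_outlines_disk_cache else "0"

# 2) ... so outlines_jump_forward, imported later by the grammar backend, sees it when it
#    computes DISABLE_DISK_CACHE at module import time.
DISABLE_DISK_CACHE = read_bool_env("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")
print(DISABLE_DISK_CACHE)  # True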