sglang · commit bca832c7 (unverified)

[Fix] fix outlines and xgrammar (#4947)

Authored Apr 21, 2025 by JieXin Liang; committed via GitHub Apr 20, 2025.
Parent: d9dd5298

Showing 5 changed files with 185 additions and 10 deletions (+185, -10):
    python/sglang/srt/constrained/outlines_jump_forward.py     +14   -1
    python/sglang/srt/constrained/triton_ops/bitmask_ops.py    +141  -0
    python/sglang/srt/constrained/xgrammar_backend.py          +26   -4
    python/sglang/srt/model_executor/model_runner.py           +0    -5
    python/sglang/srt/server_args.py                           +4    -0
python/sglang/srt/constrained/outlines_jump_forward.py

@@ -19,10 +19,13 @@ Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
 import dataclasses
 import logging
 from collections import defaultdict
+from typing import Optional
 
 import interegular
 from interegular import InvalidSyntax
-from outlines.caching import cache as disk_cache
+from outlines.caching import cache
+
+from sglang.srt.utils import get_bool_env_var
 
 try:
     # outlines >= 0.1.0
@@ -34,6 +37,9 @@ except ImportError:
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
 
+# Env var was set in sglang.srt.server_args.ServerArgs.__post__init__
+DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")
+
 logger = logging.getLogger(__name__)
@@ -45,6 +51,13 @@ class JumpEdge:
     byte_next_state: int = None
 
 
+def disk_cache(expire: Optional[float] = None, typed=False, ignore=()):
+    if not DISABLE_DISK_CACHE:
+        return cache(expire, typed, ignore)
+    else:
+        return lambda fn: fn
+
+
 @disk_cache()
 def init_state_to_jump_forward(regex_string):
     try:
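The new disk_cache wrapper turns the outlines disk cache into an import-time toggle: when the env var says caching is off, disk_cache() returns an identity decorator and the wrapped function runs uncached. Below is a minimal, self-contained sketch of the same pattern; it is not part of the commit, the get_bool_env_var body is an assumed reimplementation of the helper in sglang.srt.utils, and the init_state_to_jump_forward body is a stand-in for the real FSM construction.

import os


def get_bool_env_var(name, default="false"):
    # Assumed semantics: "true"/"1" (case-insensitive) count as True.
    return os.getenv(name, default).lower() in ("true", "1")


DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")


def disk_cache(expire=None, typed=False, ignore=()):
    if not DISABLE_DISK_CACHE:
        # Only touch outlines when caching is actually enabled.
        from outlines.caching import cache
        return cache(expire, typed, ignore)
    return lambda fn: fn  # identity decorator: no disk cache


@disk_cache()
def init_state_to_jump_forward(regex_string):
    return f"compiled:{regex_string}"  # stand-in body for illustration


print(init_state_to_jump_forward(r"\d+"))  # runs uncached while the env var is unset or "true"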
python/sglang/srt/constrained/triton_ops/bitmask_ops.py (new file, mode 100644)

# Adapt from
# https://github.com/mlc-ai/xgrammar/blob/v0.1.17/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
from typing import List, Optional, Union

import torch
import triton
import triton.language as tl

from sglang.srt.utils import get_device_core_count


@triton.jit
def apply_token_bitmask_inplace_kernel(
    logits_ptr,
    bitmask_ptr,
    indices_ptr,
    num_rows,
    vocab_size,
    logits_strides,
    bitmask_strides,
    NUM_SMS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Apply a bitmask to logits in-place using Triton. The bitmask is a 01 bitwise compressed
    tensor, where 0 means the token is masked and 1 means the token is not masked. After applying
    the bitmask, the masked logits will be set to -inf.

    Parameters
    ----------
    logits_ptr : tl.tensor
        Pointer to the logits tensor to apply the bitmask to.
    bitmask_ptr : tl.tensor
        Pointer to the bitmask tensor to apply.
    indices_ptr : Optional[tl.tensor]
        Optional pointer to indices tensor specifying which rows to apply the mask to.
    num_rows : int
        Number of rows to process. If indices_ptr is provided, this is the number of unique indices.
    vocab_size : int
        Size of the vocabulary dimension. If the logits do not have vocab padding, this is the
        same as the logits' second dimension. Otherwise, this is the actual size of the vocabulary.
    logits_strides : int
        Stride between rows in the logits tensor.
    bitmask_strides : int
        Stride between rows in the bitmask tensor.
    NUM_SMS : int
        Number of streaming multiprocessors to use.
    BLOCK_SIZE : int
        Size of processing blocks.
    """
    pid = tl.program_id(0)
    num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)
    for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
        row_id = work_id // num_blocks
        block_offset = (work_id % num_blocks) * BLOCK_SIZE
        batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
        offsets = block_offset + tl.arange(0, BLOCK_SIZE)
        bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
        vocab_mask = offsets < vocab_size
        packed_bitmask_mask = bitmask_offsets < bitmask_strides
        packed_bitmask = tl.load(
            bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets,
            packed_bitmask_mask,
        )
        bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
        bitmask = bitmask.reshape(BLOCK_SIZE)

        tl.store(
            logits_ptr + batch_id * logits_strides + offsets,
            -float("inf"),
            vocab_mask & bitmask,
        )


def apply_token_bitmask_inplace_triton(
    logits: torch.Tensor,
    bitmask: torch.Tensor,
    indices: Optional[Union[List[int], torch.Tensor]] = None,
):
    NUM_SMS = get_device_core_count()
    BLOCK_SIZE = 4096
    BITS_PER_BLOCK = 32

    # Check input dtype
    assert bitmask.dtype == torch.int32, "bitmask must be of type int32"

    # Check input tensor shapes.
    logits_shape = logits.shape
    bitmask_shape = bitmask.shape
    if logits.ndim == 1:
        logits_shape = (1, logits_shape[0])
    if bitmask.ndim == 1:
        bitmask_shape = (1, bitmask_shape[0])

    required_bitmask_width = (logits_shape[1] + BITS_PER_BLOCK - 1) // BITS_PER_BLOCK
    assert required_bitmask_width >= bitmask_shape[1], (
        f"Bitmask width too large: allow at most {required_bitmask_width} int32s for "
        f"logits' width {logits_shape[1]}, but got {bitmask_shape[1]}"
    )

    vocab_size = min(logits_shape[1], bitmask_shape[1] * BITS_PER_BLOCK)

    num_rows = None
    if isinstance(indices, list) or isinstance(indices, torch.Tensor):
        indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
        num_rows = indices.shape[0]
    else:
        assert (
            logits_shape[0] == bitmask_shape[0]
        ), f"batch size mismatch: logits {logits_shape[0]} vs bitmask {bitmask_shape[0]}"
        num_rows = logits_shape[0]

    if NUM_SMS > 0:
        grid = (NUM_SMS,)
    else:
        num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
        grid = (num_rows * num_blocks,)
        NUM_SMS = triton.next_power_of_2(grid[0])

    apply_token_bitmask_inplace_kernel[grid](
        logits,
        bitmask,
        indices,
        num_rows,
        vocab_size,
        logits_shape[1],
        bitmask_shape[1],
        NUM_SMS,
        BLOCK_SIZE,
        num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
        num_stages=3,
    )
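To make the packed-bitmask convention concrete, here is a small pure-PyTorch reference, not part of the commit, with the same semantics the kernel documents: each int32 word packs 32 mask bits (least-significant bit maps to the first token of the word), a bit value of 0 marks the token as masked, and masked logits are set to -inf.

import torch


def apply_token_bitmask_reference(logits: torch.Tensor, bitmask: torch.Tensor) -> None:
    """Set logits of masked tokens (bit == 0) to -inf, mirroring the Triton kernel."""
    batch, vocab = logits.shape
    # Unpack each int32 word into 32 bits, least-significant bit first:
    # (batch, words) -> (batch, words, 32) -> (batch, words * 32)
    bits = (bitmask[:, :, None] >> torch.arange(32, device=bitmask.device)) & 1
    keep = bits.reshape(batch, -1)[:, :vocab].bool()
    logits.masked_fill_(~keep, float("-inf"))


# Example: vocab of 5 where only tokens 0 and 3 stay allowed -> bits 0b01001 = 9.
logits = torch.zeros(1, 5)
bitmask = torch.tensor([[0b01001]], dtype=torch.int32)
apply_token_bitmask_reference(logits, bitmask)
print(logits)  # tensor([[0., -inf, -inf, 0., -inf]])

# Hypothetical GPU smoke test for the Triton op itself (needs CUDA + triton):
#   logits = torch.randn(2, 64, device="cuda")
#   bitmask = torch.full((2, 2), -1, dtype=torch.int32, device="cuda")  # all bits 1: keep everything
#   apply_token_bitmask_inplace_triton(logits, bitmask)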
python/sglang/srt/constrained/xgrammar_backend.py

@@ -25,13 +25,16 @@ from xgrammar import (
     StructuralTagItem,
     TokenizerInfo,
     allocate_token_bitmask,
-    apply_token_bitmask_inplace,
 )
 
 from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarBackend,
     BaseGrammarObject,
 )
+from sglang.srt.constrained.triton_ops.bitmask_ops import (
+    apply_token_bitmask_inplace_triton,
+)
+from sglang.srt.utils import get_bool_env_var
 
 logger = logging.getLogger(__name__)
@@ -55,6 +58,18 @@ class XGrammarGrammar(BaseGrammarObject):
         self.override_stop_tokens = override_stop_tokens
         self.finished = False
 
+        # Fix (from vLLM team): postpone the import of apply_token_bitmask_inplace_kernels to the
+        # class init site to avoid re-initializing CUDA in forked subprocess.
+        from xgrammar.kernels import apply_token_bitmask_inplace_kernels
+
+        self.use_token_bitmask_triton = get_bool_env_var(
+            "SGLANG_TOKEN_BITMASK_TRITON", "false"
+        )
+        self.apply_vocab_mask_cuda = apply_token_bitmask_inplace_kernels.get(
+            "cuda", None
+        )
+        self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_kernels.get("cpu", None)
+
     def accept_token(self, token: int):
         assert self.matcher.accept_token(token)
@@ -97,9 +112,16 @@ class XGrammarGrammar(BaseGrammarObject):
     def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
         return vocab_mask.to(device, non_blocking=True)
 
-    @staticmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        apply_token_bitmask_inplace(logits, vocab_mask)
+    def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        if (
+            not self.use_token_bitmask_triton
+            and logits.device.type == "cuda"
+            and self.apply_vocab_mask_cuda
+        ):
+            return self.apply_vocab_mask_cuda(logits, vocab_mask)
+        if logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
+            return self.apply_vocab_mask_cpu(logits, vocab_mask)
+        apply_token_bitmask_inplace_triton(logits, vocab_mask)
 
     def copy(self):
         matcher = GrammarMatcher(
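The rewritten apply_vocab_mask resolves a backend in order: xgrammar's CUDA kernel (skipped when SGLANG_TOKEN_BITMASK_TRITON is true), then xgrammar's CPU kernel, then the in-repo Triton op as the fallback. Because the flag is captured once in __init__, it has to be set before any XGrammarGrammar is constructed; a hedged sketch of forcing the Triton path:

import os

# Assumption: this runs before the grammar backend creates any
# XGrammarGrammar, since the flag is read once via get_bool_env_var
# in __init__.
os.environ["SGLANG_TOKEN_BITMASK_TRITON"] = "true"

# From here on, apply_vocab_mask skips the CUDA-kernel branch for CUDA
# tensors and calls apply_token_bitmask_inplace_triton instead.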
python/sglang/srt/model_executor/model_runner.py

@@ -137,11 +137,6 @@ class ModelRunner:
         if server_args.show_time_cost:
             enable_show_time_cost()
 
-        if server_args.disable_outlines_disk_cache:
-            from outlines.caching import disable_cache
-
-            disable_cache()
-
         # Global vars
         global_server_args_dict.update(
             {

The disable_cache() call is dropped here because disabling the outlines disk cache is now driven by the SGLANG_DISABLE_OUTLINES_DISK_CACHE env var, which ServerArgs.__post_init__ sets before the grammar backends initialize (see server_args.py below).
python/sglang/srt/server_args.py

@@ -392,6 +392,10 @@ class ServerArgs:
         os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
             "1" if self.enable_torch_compile else "0"
         )
 
+        # Set env var before grammar backends init
+        os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = (
+            "1" if self.disable_outlines_disk_cache else "0"
+        )
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
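End to end, the env var is written once as "1"/"0" in ServerArgs.__post_init__ and read at import time in outlines_jump_forward (default "true", so the cache stays off unless the server args have run). A tiny sketch of that ordering contract, with producer and consumer inlined and the same assumed get_bool_env_var semantics as above:

import os

# Producer: ServerArgs.__post_init__ runs first, before any grammar
# backend imports outlines_jump_forward.
disable_outlines_disk_cache = False  # e.g., parsed from the CLI
os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = (
    "1" if disable_outlines_disk_cache else "0"
)

# Consumer: evaluated when outlines_jump_forward is imported. "0" is
# falsy here, so with default CLI settings the disk cache stays enabled;
# if ServerArgs never ran, the "true" default keeps it disabled.
DISABLE_DISK_CACHE = os.getenv(
    "SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true"
).lower() in ("true", "1")
assert DISABLE_DISK_CACHE is False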