Commit 0e7409ad (unverified) in sglang
Authored Dec 06, 2024 by Lianmin Zheng; committed by GitHub on Dec 06, 2024

Fix the overlap for xgrammar (#2377)

Parent: 3cde5eb6
Showing 7 changed files with 145 additions and 133 deletions (+145, -133)
docs/references/supported_models.md                      +1   -1
python/sglang/srt/constrained/outlines_backend.py        +5   -0
python/sglang/srt/constrained/xgrammar_backend.py        +5   -5
python/sglang/srt/managers/scheduler.py                  +69  -65
python/sglang/srt/managers/tp_worker_overlap_thread.py   +2   -1
python/sglang/srt/sampling/sampling_batch_info.py        +9   -8
test/srt/test_json_constrained.py                        +54  -53
docs/references/supported_models.md  (+1, -1)

@@ -106,4 +106,4 @@ def import_new_model_classes():
     ModelRegistry.models.update(import_new_model_classes())

     launch_server(server_args)
-```
\ No newline at end of file
+```
python/sglang/srt/constrained/outlines_backend.py  (+5, -0)

@@ -42,6 +42,7 @@ class OutlinesGrammar(BaseGrammarObject):
         self.guide = guide
         self.jump_forward_map = jump_forward_map
         self.state = 0
+        self.finished = False

     def accept_token(self, token: int):
         self.state = self.guide.get_next_state(self.state, token)

@@ -84,6 +85,10 @@ class OutlinesGrammar(BaseGrammarObject):
     ) -> torch.Tensor:
         return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask
+
     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
         tokens = torch.tensor(
             self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
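For the Outlines backend the vocab mask is already allocated on the sampling device (see `allocate_vocab_mask` above), so the new `move_vocab_mask` hook has nothing to copy and simply returns its input. A minimal standalone sketch of that pair of hooks (plain functions standing in for the `OutlinesGrammar` static methods, not the sglang code):

```python
# Sketch only (not the sglang implementation): the Outlines-style mask is
# created directly on the target device, so "moving" it is a no-op.
import torch


def allocate_vocab_mask(vocab_size: int, batch_size: int, device) -> torch.Tensor:
    # True will mean "block this token"; start with everything allowed.
    return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)


def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
    return vocab_mask  # already on `device`, nothing to do


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    mask = allocate_vocab_mask(vocab_size=32000, batch_size=4, device=device)
    assert move_vocab_mask(mask, device) is mask
```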
python/sglang/srt/constrained/xgrammar_backend.py  (+5, -5)

@@ -45,6 +45,7 @@ class XGrammarGrammar(BaseGrammarObject):
         self.matcher = matcher
         self.vocab_size = vocab_size
         self.ctx = ctx
+        self.finished = False

     def accept_token(self, token: int):
         assert self.matcher.accept_token(token)

@@ -85,12 +86,11 @@ class XGrammarGrammar(BaseGrammarObject):
         self.matcher.fill_next_token_bitmask(vocab_mask, idx)

     @staticmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        if vocab_mask.device.type != logits.device.type:
-            # vocab_mask must then be on the same device as logits
-            # when applying the token bitmask, so we check and move if needed
-            vocab_mask = vocab_mask.to(logits.device)
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask.to(device, non_blocking=True)
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
         apply_token_bitmask_inplace(logits, vocab_mask)

     def copy(self):
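This is the core of the commit: device handling is pulled out of `apply_vocab_mask` into a separate `move_vocab_mask` step, so the mask can be filled on the host and moved to the sampling device with one explicit, non-blocking copy before `apply_token_bitmask_inplace` runs. A rough standalone sketch of that fill-move-apply pattern (a boolean mask and `masked_fill_` stand in for xgrammar's real bitmask format and kernel; the pinned-memory detail is an assumption added here for illustration):

```python
# Sketch only: fill the mask on the host, move it once, apply it on the device.
import torch


def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
    # One explicit host-to-device copy; non_blocking=True lets it overlap with
    # other work (it is only truly asynchronous from pinned host memory).
    return vocab_mask.to(device, non_blocking=True)


def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
    # Stand-in for xgrammar's apply_token_bitmask_inplace.
    logits.masked_fill_(vocab_mask, float("-inf"))


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    vocab_mask = torch.zeros(4, 32000, dtype=torch.bool)
    if device == "cuda":
        vocab_mask = vocab_mask.pin_memory()
    vocab_mask[:, 1::2] = True                     # pretend the matcher filled it
    logits = torch.randn(4, 32000, device=device)
    apply_vocab_mask(logits, move_vocab_mask(vocab_mask, device))
    print(logits[0, :4])
```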
python/sglang/srt/managers/scheduler.py  (+69, -65)

@@ -114,9 +114,6 @@ class Scheduler:
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics

-        # Session info
-        self.sessions = {}
-
         # Init inter-process communication
         context = zmq.Context(2)

@@ -259,6 +256,10 @@ class Scheduler:
         self.num_generated_tokens = 0
         self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
+        self.current_stream = torch.get_device_module(self.device).current_stream()
+
+        # Session info
+        self.sessions = {}

         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size

@@ -356,6 +357,7 @@ class Scheduler:
         )

     def watchdog_thread(self):
+        """A watch dog thread that will try to kill the server itself if one batch takes too long."""
         self.watchdog_last_forward_ct = 0
         self.watchdog_last_time = time.time()

@@ -433,61 +435,6 @@ class Scheduler:
         self.last_batch = batch

-    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
-        # Check if other DP workers have running batches
-        if local_batch is None:
-            num_tokens = 0
-        elif local_batch.forward_mode.is_decode():
-            num_tokens = local_batch.batch_size()
-        else:
-            num_tokens = local_batch.extend_num_tokens
-
-        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
-        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
-        torch.distributed.all_gather_into_tensor(
-            global_num_tokens,
-            local_num_tokens,
-            group=self.tp_cpu_group,
-        )
-
-        if local_batch is None and global_num_tokens.max().item() > 0:
-            local_batch = self.get_idle_batch()
-
-        if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens.tolist()
-
-            # Check forward mode for cuda graph
-            if not self.server_args.disable_cuda_graph:
-                forward_mode_state = torch.tensor(
-                    (
-                        1
-                        if local_batch.forward_mode.is_decode()
-                        or local_batch.forward_mode.is_idle()
-                        else 0
-                    ),
-                    dtype=torch.int32,
-                )
-                torch.distributed.all_reduce(
-                    forward_mode_state,
-                    op=torch.distributed.ReduceOp.MIN,
-                    group=self.tp_cpu_group,
-                )
-                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
-
-        return local_batch
-
-    def get_idle_batch(self):
-        idle_batch = ScheduleBatch.init_new(
-            [],
-            self.req_to_token_pool,
-            self.token_to_kv_pool,
-            self.tree_cache,
-            self.model_config,
-            self.enable_overlap,
-        )
-        idle_batch.prepare_for_idle()
-        return idle_batch
-
     def recv_requests(self):
         if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []

@@ -993,7 +940,7 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.get_device_module(self.device).current_stream().synchronize()
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

     def process_batch_result_prefill(self, batch: ScheduleBatch, result):

@@ -1049,13 +996,14 @@ class Scheduler:
                     if req.grammar is not None:
                         req.grammar.accept_token(next_token_id)
+                        req.grammar.finished = req.finished()
                 else:
                     # being chunked reqs' prefill is not finished
                     req.is_being_chunked -= 1

             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
-                torch.get_device_module(self.device).current_stream().synchronize()
+                self.current_stream.synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()

         else:  # embedding or reward model

@@ -1127,10 +1075,11 @@ class Scheduler:
                 if req.grammar is not None:
                     req.grammar.accept_token(next_token_id)
+                    req.grammar.finished = req.finished()

         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.get_device_module(self.device).current_stream().synchronize()
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

         self.stream_output(batch.reqs)

@@ -1328,6 +1277,61 @@
             )
         )

+    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        # Check if other DP workers have running batches
+        if local_batch is None:
+            num_tokens = 0
+        elif local_batch.forward_mode.is_decode():
+            num_tokens = local_batch.batch_size()
+        else:
+            num_tokens = local_batch.extend_num_tokens
+
+        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+        torch.distributed.all_gather_into_tensor(
+            global_num_tokens,
+            local_num_tokens,
+            group=self.tp_cpu_group,
+        )
+
+        if local_batch is None and global_num_tokens.max().item() > 0:
+            local_batch = self.get_idle_batch()
+
+        if local_batch is not None:
+            local_batch.global_num_tokens = global_num_tokens.tolist()
+
+            # Check forward mode for cuda graph
+            if not self.server_args.disable_cuda_graph:
+                forward_mode_state = torch.tensor(
+                    (
+                        1
+                        if local_batch.forward_mode.is_decode()
+                        or local_batch.forward_mode.is_idle()
+                        else 0
+                    ),
+                    dtype=torch.int32,
+                )
+                torch.distributed.all_reduce(
+                    forward_mode_state,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_cpu_group,
+                )
+                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+
+        return local_batch
+
+    def get_idle_batch(self):
+        idle_batch = ScheduleBatch.init_new(
+            [],
+            self.req_to_token_pool,
+            self.token_to_kv_pool,
+            self.tree_cache,
+            self.model_config,
+            self.enable_overlap,
+        )
+        idle_batch.prepare_for_idle()
+        return idle_batch
+
     def move_ready_grammar_requests(self):
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
         num_ready_reqs = 0

@@ -1469,10 +1473,6 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
-    # set cpu affinity to this gpu process
-    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
-        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
-
     # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
     if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
         dp_rank = int(os.environ["SGLANG_DP_RANK"])

@@ -1482,6 +1482,10 @@ def run_scheduler_process(
     else:
         configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")

+    # set cpu affinity to this gpu process
+    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
+        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+
     suppress_other_loggers()
     parent_process = psutil.Process().parent()
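The scheduler changes share one theme: the device stream is looked up once and cached as `self.current_stream`, and it is synchronized before `sampling_info_done` is set, so the overlap worker never samples against a half-written grammar mask; grammars are also marked `finished` when their request finishes so later mask fills can skip them. A toy sketch of that hand-off between the scheduler thread and a worker thread (illustrative names, not the sglang API):

```python
# Toy sketch: producer fills the next batch's vocab mask, syncs its cached
# stream so the writes are complete, and only then releases the worker.
import threading

import torch


class NextBatchSamplingInfo:
    def __init__(self):
        self.vocab_mask = None
        self.sampling_info_done = threading.Event()

    def update_regex_vocab_mask(self, device):
        self.vocab_mask = torch.zeros(1, 128, dtype=torch.bool, device=device)


def scheduler_side(info: NextBatchSamplingInfo, device: str):
    # Cached once in the real Scheduler as self.current_stream.
    current_stream = (
        torch.get_device_module(device).current_stream() if device == "cuda" else None
    )
    info.update_regex_vocab_mask(device)
    if current_stream is not None:
        current_stream.synchronize()  # mask writes finished ...
    info.sampling_info_done.set()     # ... before the worker is released


def worker_side(info: NextBatchSamplingInfo):
    info.sampling_info_done.wait()
    print("mask ready:", tuple(info.vocab_mask.shape))


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    info = NextBatchSamplingInfo()
    worker = threading.Thread(target=worker_side, args=(info,))
    worker.start()
    scheduler_side(info, device)
    worker.join()
```

In the real scheduler the event lives on `batch.next_batch_sampling_info` and the mask filling happens in `SamplingBatchInfo.update_regex_vocab_mask`, shown in the sampling diff below.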
python/sglang/srt/managers/tp_worker_overlap_thread.py  (+2, -1)

@@ -80,6 +80,7 @@ class TpModelWorkerClient:
         )
         self.forward_thread.start()
         self.parent_process = psutil.Process().parent()
+        self.scheduler_stream = torch.get_device_module(self.device).current_stream()

     def get_worker_info(self):
         return self.worker.get_worker_info()

@@ -191,7 +192,7 @@ class TpModelWorkerClient:
         )

         # A cuda stream sync here to avoid the cuda illegal memory access error.
-        torch.get_device_module(self.device).current_stream().synchronize()
+        self.scheduler_stream.synchronize()

         # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
python/sglang/srt/sampling/sampling_batch_info.py  (+9, -8)

@@ -158,22 +158,23 @@ class SamplingBatchInfo:
             return

         # find a grammar from the list
-        grammar = next(grammar for grammar in self.grammars if grammar)
+        first_grammar = next(grammar for grammar in self.grammars if grammar)

         # maybe we can reuse the existing mask?
-        self.vocab_mask = grammar.allocate_vocab_mask(
+        self.vocab_mask = first_grammar.allocate_vocab_mask(
             vocab_size=self.vocab_size,
             batch_size=len(self.temperatures),
             device=self.device,
        )
-        self.apply_mask = type(grammar).apply_vocab_mask  # force to use static method
+        self.apply_mask = first_grammar.apply_vocab_mask  # force to use static method

+        # Apply the mask
         for i, grammar in enumerate(self.grammars):
-            if grammar is not None:
-                try:
-                    grammar.fill_vocab_mask(self.vocab_mask, i)
-                except RuntimeError:
-                    continue
+            if grammar and not grammar.finished:
+                grammar.fill_vocab_mask(self.vocab_mask, i)
+
+        # Move the mask to the device if needed
+        self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)

     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
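Renaming the selected grammar to `first_grammar` keeps it from being shadowed by the loop variable `grammar`, and the switch from `type(grammar).apply_vocab_mask` to `first_grammar.apply_vocab_mask` is safe because accessing a `@staticmethod` through an instance still yields the plain, unbound function. A tiny sketch of that point (toy class, not the sglang one):

```python
# Accessing a @staticmethod through an instance returns the same plain function
# as accessing it through the class, so stashing first_grammar.apply_vocab_mask
# keeps the "force to use static method" behavior.
import torch


class ToyGrammar:
    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        logits.masked_fill_(vocab_mask, float("-inf"))


if __name__ == "__main__":
    g = ToyGrammar()
    assert g.apply_vocab_mask is ToyGrammar.apply_vocab_mask  # same function object
    apply_mask = g.apply_vocab_mask  # no bound self captured; safe to stash
    logits = torch.zeros(1, 4)
    apply_mask(logits, torch.tensor([[True, False, True, False]]))
    print(logits)  # tensor([[-inf, 0., -inf, 0.]])
```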
test/srt/test_json_constrained.py  (+54, -53)

 """
-python3 -m unittest test_json_constrained.TestJSONConstrained.test_json_generate
+python3 -m unittest test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate
+python3 -m unittest test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate
 """

 import json

@@ -11,38 +12,50 @@ import requests
 from sglang.srt.utils import kill_process_tree
 from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
 )


+def setup_class(cls, backend: str, disable_overlap: bool):
+    cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+    cls.base_url = DEFAULT_URL_FOR_TEST
+    cls.json_schema = json.dumps(
+        {
+            "type": "object",
+            "properties": {
+                "name": {"type": "string", "pattern": "^[\\w]+$"},
+                "population": {"type": "integer"},
+            },
+            "required": ["name", "population"],
+        }
+    )
+
+    other_args = [
+        "--max-running-requests",
+        "10",
+        "--grammar-backend",
+        backend,
+    ]
+    if disable_overlap:
+        other_args += ["--disable-overlap-schedule"]
+
+    cls.process = popen_launch_server(
+        cls.model,
+        cls.base_url,
+        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+        other_args=other_args,
+    )
+
+
 class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
-        cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.json_schema = json.dumps(
-            {
-                "type": "object",
-                "properties": {
-                    "name": {"type": "string", "pattern": "^[\\w]+$"},
-                    "population": {"type": "integer"},
-                },
-                "required": ["name", "population"],
-            }
-        )
-        cls.process = popen_launch_server(
-            cls.model,
-            cls.base_url,
-            timeout=300,
-            other_args=[
-                "--max-running-requests",
-                "10",
-                "--grammar-backend",
-                "outlines",
-            ],
-        )
+        setup_class(cls, backend="outlines", disable_overlap=False)
+        cls.check_jump_forward = False

     @classmethod
     def tearDownClass(cls):

@@ -83,11 +96,13 @@ class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
         self.assertIsInstance(js_obj["population"], int)

         # Make sure jump forward is triggered
-        # NOTE: This is skipped because overlap scheduler does not support jump forward
-        # self.assertGreater(
-        #     ret["meta_info"]["completion_tokens"],
-        #     ret["meta_info"]["completion_tokens_wo_jump_forward"],
-        # )
+        # NOTE: The overlap scheduler does not support jump forward so we only do this test
+        # when --disable-overlap-schedule is set.
+        if self.check_jump_forward:
+            self.assertGreater(
+                ret["meta_info"]["completion_tokens"],
+                ret["meta_info"]["completion_tokens_wo_jump_forward"],
+            )

     def test_json_generate(self):
         self.run_decode(json_schema=self.json_schema)

@@ -126,32 +141,18 @@
         list(executor.map(self.run_decode, json_schemas))


+class TestJumpForwardOutlinesBackend(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        setup_class(cls, backend="outlines", disable_overlap=True)
+        cls.check_jump_forward = True
+
+
 class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
-        cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.json_schema = json.dumps(
-            {
-                "type": "object",
-                "properties": {
-                    "name": {"type": "string"},
-                    "population": {"type": "integer"},
-                },
-                "required": ["name", "population"],
-            }
-        )
-        cls.process = popen_launch_server(
-            cls.model,
-            cls.base_url,
-            timeout=300,
-            other_args=[
-                "--max-running-requests",
-                "10",
-                "--grammar-backend",
-                "xgrammar",
-            ],
-        )
+        setup_class(cls, backend="xgrammar", disable_overlap=False)
+        cls.check_jump_forward = False


 if __name__ == "__main__":
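The tests are now parametrized through the module-level `setup_class(cls, backend, disable_overlap)` helper, so each test class only picks a grammar backend and whether to disable the overlap scheduler (required for the jump-forward assertion). A toy, server-free illustration of the same layout (the class names and assertions here are illustrative only; the real tests launch an sglang server instead):

```python
# Toy illustration of the test layout after this commit: one module-level
# setup helper, and per-combination classes that only choose parameters.
import unittest


def setup_class(cls, backend: str, disable_overlap: bool):
    cls.other_args = ["--max-running-requests", "10", "--grammar-backend", backend]
    if disable_overlap:
        cls.other_args += ["--disable-overlap-schedule"]


class TestOutlinesConfig(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        setup_class(cls, backend="outlines", disable_overlap=False)
        cls.check_jump_forward = False

    def test_config(self):
        self.assertIn("--grammar-backend", self.other_args)
        if self.check_jump_forward:
            self.assertIn("--disable-overlap-schedule", self.other_args)


class TestJumpForwardOutlinesConfig(TestOutlinesConfig):
    # Inherits test_config; only the launch parameters differ.
    @classmethod
    def setUpClass(cls):
        setup_class(cls, backend="outlines", disable_overlap=True)
        cls.check_jump_forward = True


if __name__ == "__main__":
    unittest.main()
```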