Commit 0e7409ad (sglang)

Fix the overlap for xgrammar (#2377)

Authored Dec 06, 2024 by Lianmin Zheng; committed by GitHub on Dec 06, 2024 (unverified signature).
Parent commit: 3cde5eb6

Showing 7 changed files with 145 additions and 133 deletions.
Changed files:

- docs/references/supported_models.md (+1, -1)
- python/sglang/srt/constrained/outlines_backend.py (+5, -0)
- python/sglang/srt/constrained/xgrammar_backend.py (+5, -5)
- python/sglang/srt/managers/scheduler.py (+69, -65)
- python/sglang/srt/managers/tp_worker_overlap_thread.py (+2, -1)
- python/sglang/srt/sampling/sampling_batch_info.py (+9, -8)
- test/srt/test_json_constrained.py (+54, -53)
docs/references/supported_models.md

````diff
@@ -106,4 +106,4 @@ def import_new_model_classes():
     ModelRegistry.models.update(import_new_model_classes())
 
     launch_server(server_args)
-```
\ No newline at end of file
+```
````
python/sglang/srt/constrained/outlines_backend.py

```diff
@@ -42,6 +42,7 @@ class OutlinesGrammar(BaseGrammarObject):
         self.guide = guide
         self.jump_forward_map = jump_forward_map
         self.state = 0
+        self.finished = False
 
     def accept_token(self, token: int):
         self.state = self.guide.get_next_state(self.state, token)
@@ -84,6 +85,10 @@ class OutlinesGrammar(BaseGrammarObject):
     ) -> torch.Tensor:
         return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
 
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask
+
     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
         tokens = torch.tensor(
             self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
```
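After this commit, both grammar backends expose the same four-step mask protocol: `allocate_vocab_mask`, `fill_vocab_mask`, `move_vocab_mask`, `apply_vocab_mask`. For Outlines the mask already lives on the target device, so `move_vocab_mask` is a no-op. A minimal sketch of how a caller might drive the protocol; everything except those four method names is a hypothetical stand-in, not sglang code:

```python
import torch

def apply_grammar_masks(grammars, logits: torch.Tensor) -> None:
    """Hypothetical driver for the allocate/fill/move/apply mask protocol.

    `grammars` is a per-request list whose entries may be None; `logits`
    has shape (batch_size, vocab_size).
    """
    first_grammar = next((g for g in grammars if g), None)
    if first_grammar is None:
        return

    batch_size, vocab_size = logits.shape
    # 1. Allocate: Outlines allocates on the target device; XGrammar on CPU.
    vocab_mask = first_grammar.allocate_vocab_mask(
        vocab_size=vocab_size, batch_size=batch_size, device=logits.device
    )
    # 2. Fill one row per request, skipping finished grammars.
    for i, grammar in enumerate(grammars):
        if grammar and not grammar.finished:
            grammar.fill_vocab_mask(vocab_mask, i)
    # 3. Move: a no-op for Outlines, an async host-to-device copy for XGrammar.
    vocab_mask = first_grammar.move_vocab_mask(vocab_mask, logits.device)
    # 4. Apply in place.
    first_grammar.apply_vocab_mask(logits, vocab_mask)
```

Splitting the move out as its own step is what lets one code path serve a backend whose mask is device-resident and one whose mask is CPU-resident.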
python/sglang/srt/constrained/xgrammar_backend.py

```diff
@@ -45,6 +45,7 @@ class XGrammarGrammar(BaseGrammarObject):
         self.matcher = matcher
         self.vocab_size = vocab_size
         self.ctx = ctx
+        self.finished = False
 
     def accept_token(self, token: int):
         assert self.matcher.accept_token(token)
@@ -85,12 +86,11 @@ class XGrammarGrammar(BaseGrammarObject):
         self.matcher.fill_next_token_bitmask(vocab_mask, idx)
 
     @staticmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        if vocab_mask.device.type != logits.device.type:
-            # vocab_mask must then be on the same device as logits
-            # when applying the token bitmask, so we check and move if needed
-            vocab_mask = vocab_mask.to(logits.device)
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask.to(device, non_blocking=True)
 
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
         apply_token_bitmask_inplace(logits, vocab_mask)
 
     def copy(self):
```
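XGrammar fills its bitmask on the CPU, so the new `move_vocab_mask` step is an asynchronous host-to-device copy (`non_blocking=True`) rather than the old synchronous check-and-move inside `apply_vocab_mask`. A small sketch of that pattern, assuming CUDA is available; the boolean mask here is a stand-in for the real packed bitmask, and pinned memory is an assumption (it is what makes `non_blocking=True` truly asynchronous):

```python
import torch

vocab_size, batch_size = 32000, 4
# Pinned (page-locked) host memory is assumed so the copy can be async.
cpu_mask = torch.zeros(batch_size, vocab_size, dtype=torch.bool, pin_memory=True)
cpu_mask[:, ::2] = True  # stand-in for matcher.fill_next_token_bitmask(...)

device = torch.device("cuda")
logits = torch.randn(batch_size, vocab_size, device=device)

# The copy is enqueued on the current stream and returns immediately...
gpu_mask = cpu_mask.to(device, non_blocking=True)
# ...so other CPU-side scheduling work can proceed before the mask is consumed.
logits.masked_fill_(~gpu_mask, float("-inf"))  # stand-in for apply_token_bitmask_inplace
```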
python/sglang/srt/managers/scheduler.py

```diff
@@ -114,9 +114,6 @@ class Scheduler:
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
 
-        # Session info
-        self.sessions = {}
-
         # Init inter-process communication
         context = zmq.Context(2)
@@ -259,6 +256,10 @@ class Scheduler:
         self.num_generated_tokens = 0
         self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
+        self.current_stream = torch.get_device_module(self.device).current_stream()
+
+        # Session info
+        self.sessions = {}
 
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
```
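Rather than calling `torch.get_device_module(self.device).current_stream()` on every batch, the scheduler now looks the stream handle up once in `__init__` and caches it as `self.current_stream`, so the hot path pays only for the `synchronize()` itself. A hedged micro-sketch of the pattern (the class and method names are illustrative, and a recent PyTorch with `torch.get_device_module` plus an available accelerator is assumed):

```python
import torch

class StreamUser:
    """Illustrative only: cache the device stream handle once, reuse it later."""

    def __init__(self, device: str = "cuda"):
        self.device = device
        # One-time lookup; torch.get_device_module maps "cuda" -> torch.cuda, etc.
        self.current_stream = torch.get_device_module(device).current_stream()

    def after_batch(self):
        # Hot path: no module/stream lookup, just the synchronization itself.
        self.current_stream.synchronize()
```

This caching assumes the scheduler always runs on the stream that was current at construction time, which holds for this single-stream scheduler loop.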
```diff
@@ -356,6 +357,7 @@ class Scheduler:
         )
 
     def watchdog_thread(self):
+        """A watch dog thread that will try to kill the server itself if one batch takes too long."""
         self.watchdog_last_forward_ct = 0
         self.watchdog_last_time = time.time()
@@ -433,61 +435,6 @@ class Scheduler:
         self.last_batch = batch
 
-    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
-        # Check if other DP workers have running batches
-        if local_batch is None:
-            num_tokens = 0
-        elif local_batch.forward_mode.is_decode():
-            num_tokens = local_batch.batch_size()
-        else:
-            num_tokens = local_batch.extend_num_tokens
-
-        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
-        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
-        torch.distributed.all_gather_into_tensor(
-            global_num_tokens,
-            local_num_tokens,
-            group=self.tp_cpu_group,
-        )
-
-        if local_batch is None and global_num_tokens.max().item() > 0:
-            local_batch = self.get_idle_batch()
-
-        if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens.tolist()
-
-            # Check forward mode for cuda graph
-            if not self.server_args.disable_cuda_graph:
-                forward_mode_state = torch.tensor(
-                    (
-                        1
-                        if local_batch.forward_mode.is_decode()
-                        or local_batch.forward_mode.is_idle()
-                        else 0
-                    ),
-                    dtype=torch.int32,
-                )
-                torch.distributed.all_reduce(
-                    forward_mode_state,
-                    op=torch.distributed.ReduceOp.MIN,
-                    group=self.tp_cpu_group,
-                )
-                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
-
-        return local_batch
-
-    def get_idle_batch(self):
-        idle_batch = ScheduleBatch.init_new(
-            [],
-            self.req_to_token_pool,
-            self.token_to_kv_pool,
-            self.tree_cache,
-            self.model_config,
-            self.enable_overlap,
-        )
-        idle_batch.prepare_for_idle()
-        return idle_batch
-
     def recv_requests(self):
         if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []
@@ -993,7 +940,7 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.get_device_module(self.device).current_stream().synchronize()
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
@@ -1049,13 +996,14 @@ class Scheduler:
                 if req.grammar is not None:
                     req.grammar.accept_token(next_token_id)
+                    req.grammar.finished = req.finished()
             else:
                 # being chunked reqs' prefill is not finished
                 req.is_being_chunked -= 1
 
         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.get_device_module(self.device).current_stream().synchronize()
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
         else:  # embedding or reward model
@@ -1127,10 +1075,11 @@ class Scheduler:
             if req.grammar is not None:
                 req.grammar.accept_token(next_token_id)
+                req.grammar.finished = req.finished()
 
         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.get_device_module(self.device).current_stream().synchronize()
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
         self.stream_output(batch.reqs)
```
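The new `req.grammar.finished = req.finished()` lines are the heart of the overlap fix: with the overlap scheduler, the vocab mask for the next step is built before the scheduler knows a request has just stopped, so a finished grammar must be skipped explicitly rather than asked to produce a next-token mask it no longer has. A toy sketch of that invariant; `ToyGrammar` is a stand-in for the real grammar classes, not sglang code:

```python
class ToyGrammar:
    """Toy stand-in for a grammar object."""

    def __init__(self):
        self.finished = False

    def accept_token(self, token: int):
        assert not self.finished, "must not advance a finished matcher"

    def fill_vocab_mask(self, mask, idx: int):
        assert not self.finished, "no next-token mask exists after the final token"
        mask[idx] = True  # stand-in for the real bitmask fill

# Scheduler side: after sampling, mark grammars of finished requests...
grammars = [ToyGrammar(), ToyGrammar()]
grammars[1].accept_token(42)
grammars[1].finished = True  # request 1 hit its stop condition

# ...so the next overlapped mask build skips them instead of raising.
mask = [False, False]
for i, g in enumerate(grammars):
    if g and not g.finished:
        g.fill_vocab_mask(mask, i)
assert mask == [True, False]
```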
```diff
@@ -1328,6 +1277,61 @@ class Scheduler:
             )
         )
 
+    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        # Check if other DP workers have running batches
+        if local_batch is None:
+            num_tokens = 0
+        elif local_batch.forward_mode.is_decode():
+            num_tokens = local_batch.batch_size()
+        else:
+            num_tokens = local_batch.extend_num_tokens
+
+        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+        torch.distributed.all_gather_into_tensor(
+            global_num_tokens,
+            local_num_tokens,
+            group=self.tp_cpu_group,
+        )
+
+        if local_batch is None and global_num_tokens.max().item() > 0:
+            local_batch = self.get_idle_batch()
+
+        if local_batch is not None:
+            local_batch.global_num_tokens = global_num_tokens.tolist()
+
+            # Check forward mode for cuda graph
+            if not self.server_args.disable_cuda_graph:
+                forward_mode_state = torch.tensor(
+                    (
+                        1
+                        if local_batch.forward_mode.is_decode()
+                        or local_batch.forward_mode.is_idle()
+                        else 0
+                    ),
+                    dtype=torch.int32,
+                )
+                torch.distributed.all_reduce(
+                    forward_mode_state,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_cpu_group,
+                )
+                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+
+        return local_batch
+
+    def get_idle_batch(self):
+        idle_batch = ScheduleBatch.init_new(
+            [],
+            self.req_to_token_pool,
+            self.token_to_kv_pool,
+            self.tree_cache,
+            self.model_config,
+            self.enable_overlap,
+        )
+        idle_batch.prepare_for_idle()
+        return idle_batch
+
     def move_ready_grammar_requests(self):
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
         num_ready_reqs = 0
```
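`prepare_dp_attn_batch` and `get_idle_batch` are moved unchanged by this commit (the earlier hunk deletes them, this one re-adds them lower in the class). Their logic rests on two tiny collectives: an all-gather of per-rank token counts so a rank with no work can substitute an idle batch and keep collectives aligned, and a MIN all-reduce so CUDA graphs are used only when every rank is in decode or idle mode. A self-contained single-process sketch of that decision logic with the collectives stubbed out; all names here are illustrative:

```python
from typing import List, Optional

def can_run_dp_cuda_graph(per_rank_modes: List[str]) -> bool:
    """MIN-all_reduce equivalent: true only if every rank votes 1."""
    votes = [1 if m in ("decode", "idle") else 0 for m in per_rank_modes]
    return min(votes) == 1

def decide_batches(per_rank_tokens: List[Optional[int]]) -> List[str]:
    """all_gather equivalent: each rank sees everyone's token count.

    None means "no local batch"; such ranks get an idle batch whenever
    any other rank has tokens, so collective ops stay in lockstep.
    """
    gathered = [t or 0 for t in per_rank_tokens]
    any_work = max(gathered) > 0
    return [
        "idle" if t is None and any_work else ("batch" if t else "none")
        for t in per_rank_tokens
    ]

# Rank 1 has nothing to run, but rank 0 does -> rank 1 runs an idle batch.
assert decide_batches([7, None]) == ["batch", "idle"]
assert can_run_dp_cuda_graph(["decode", "idle"]) is True
assert can_run_dp_cuda_graph(["decode", "extend"]) is False
```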
```diff
@@ -1469,10 +1473,6 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
-    # set cpu affinity to this gpu process
-    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
-        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
-
     # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
     if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
         dp_rank = int(os.environ["SGLANG_DP_RANK"])
@@ -1482,6 +1482,10 @@ def run_scheduler_process(
     else:
         configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
 
+    # set cpu affinity to this gpu process
+    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
+        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+
     suppress_other_loggers()
     parent_process = psutil.Process().parent()
```
python/sglang/srt/managers/tp_worker_overlap_thread.py

```diff
@@ -80,6 +80,7 @@ class TpModelWorkerClient:
         )
         self.forward_thread.start()
         self.parent_process = psutil.Process().parent()
+        self.scheduler_stream = torch.get_device_module(self.device).current_stream()
 
     def get_worker_info(self):
         return self.worker.get_worker_info()
@@ -191,7 +192,7 @@ class TpModelWorkerClient:
         )
 
         # A cuda stream sync here to avoid the cuda illegal memory access error.
-        torch.get_device_module(self.device).current_stream().synchronize()
+        self.scheduler_stream.synchronize()
 
         # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
```
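The worker client caches the scheduler-side stream the same way, and the comment in the diff explains why the sync before `input_queue.put` exists: the background forward thread must not touch tensors whose producing kernels, enqueued by the scheduler on its own stream, have not completed. A toy sketch of that handoff; the queue, sentinel, and loop are illustrative, and CUDA is optional here:

```python
import queue
import threading
import torch

work_queue: "queue.Queue[torch.Tensor]" = queue.Queue()
scheduler_stream = torch.cuda.current_stream() if torch.cuda.is_available() else None

def forward_loop():
    while True:
        batch = work_queue.get()
        if batch is None:  # shutdown sentinel
            return
        _ = batch * 2  # stand-in for the real forward pass

threading.Thread(target=forward_loop, daemon=True).start()

batch = torch.ones(8)
if scheduler_stream is not None:
    # Without this, the forward thread could read tensors whose producing
    # kernels (enqueued by the scheduler) have not finished yet.
    scheduler_stream.synchronize()
work_queue.put(batch)
work_queue.put(None)
```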
python/sglang/srt/sampling/sampling_batch_info.py

```diff
@@ -158,22 +158,23 @@ class SamplingBatchInfo:
             return
 
         # find a grammar from the list
-        grammar = next(grammar for grammar in self.grammars if grammar)
+        first_grammar = next(grammar for grammar in self.grammars if grammar)
 
         # maybe we can reuse the existing mask?
-        self.vocab_mask = grammar.allocate_vocab_mask(
+        self.vocab_mask = first_grammar.allocate_vocab_mask(
             vocab_size=self.vocab_size,
             batch_size=len(self.temperatures),
             device=self.device,
         )
-        self.apply_mask = type(grammar).apply_vocab_mask  # force to use static method
+        self.apply_mask = first_grammar.apply_vocab_mask  # force to use static method
 
+        # Apply the mask
         for i, grammar in enumerate(self.grammars):
-            if grammar is not None:
-                try:
-                    grammar.fill_vocab_mask(self.vocab_mask, i)
-                except RuntimeError:
-                    continue
+            if grammar and not grammar.finished:
+                grammar.fill_vocab_mask(self.vocab_mask, i)
 
+        # Move the mask to the device if needed
+        self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)
+
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
```
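Two things change here: the exception-driven skip (catching `RuntimeError` from a finished XGrammar matcher) becomes an explicit `finished` check, and the simplified `first_grammar.apply_vocab_mask` relies on a Python detail: accessing a `@staticmethod` through an instance already yields the plain unbound function, so the previous `type(grammar).apply_vocab_mask` indirection was unnecessary. A quick self-contained check of that detail (the class is a toy, not the sglang one):

```python
class Grammar:
    @staticmethod
    def apply_vocab_mask(logits, vocab_mask):
        return "applied"

g = Grammar()
# staticmethod descriptors do not bind: the instance lookup and the class
# lookup return the very same function object.
assert g.apply_vocab_mask is Grammar.apply_vocab_mask
assert g.apply_vocab_mask("logits", "mask") == "applied"
```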
test/srt/test_json_constrained.py

```diff
 """
-python3 -m unittest test_json_constrained.TestJSONConstrained.test_json_generate
+python3 -m unittest test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate
+python3 -m unittest test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate
 """
 
 import json
```
```diff
@@ -11,38 +12,50 @@ import requests
 from sglang.srt.utils import kill_process_tree
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
 )
 
+def setup_class(cls, backend: str, disable_overlap: bool):
+    cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+    cls.base_url = DEFAULT_URL_FOR_TEST
+    cls.json_schema = json.dumps(
+        {
+            "type": "object",
+            "properties": {
+                "name": {"type": "string", "pattern": "^[\\w]+$"},
+                "population": {"type": "integer"},
+            },
+            "required": ["name", "population"],
+        }
+    )
+
+    other_args = [
+        "--max-running-requests",
+        "10",
+        "--grammar-backend",
+        backend,
+    ]
+    if disable_overlap:
+        other_args += ["--disable-overlap-schedule"]
+
+    cls.process = popen_launch_server(
+        cls.model,
+        cls.base_url,
+        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+        other_args=other_args,
+    )
+
 class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
-        cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.json_schema = json.dumps(
-            {
-                "type": "object",
-                "properties": {
-                    "name": {"type": "string", "pattern": "^[\\w]+$"},
-                    "population": {"type": "integer"},
-                },
-                "required": ["name", "population"],
-            }
-        )
-        cls.process = popen_launch_server(
-            cls.model,
-            cls.base_url,
-            timeout=300,
-            other_args=[
-                "--max-running-requests",
-                "10",
-                "--grammar-backend",
-                "outlines",
-            ],
-        )
+        setup_class(cls, backend="outlines", disable_overlap=False)
+        cls.check_jump_forward = False
 
     @classmethod
     def tearDownClass(cls):
```
```diff
@@ -83,11 +96,13 @@ class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
         self.assertIsInstance(js_obj["population"], int)
 
-        # Make sure jump forward is triggered
-        # NOTE: This is skipped because overlap scheduler does not support jump forward
-        # self.assertGreater(
-        #     ret["meta_info"]["completion_tokens"],
-        #     ret["meta_info"]["completion_tokens_wo_jump_forward"],
-        # )
+        # NOTE: The overlap scheduler does not support jump forward so we only do this test
+        # when --disable-overlap-schedule is set.
+        if self.check_jump_forward:
+            self.assertGreater(
+                ret["meta_info"]["completion_tokens"],
+                ret["meta_info"]["completion_tokens_wo_jump_forward"],
+            )
 
     def test_json_generate(self):
         self.run_decode(json_schema=self.json_schema)
```
```diff
@@ -126,32 +141,18 @@ class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
             list(executor.map(self.run_decode, json_schemas))
 
+class TestJumpForwardOutlinesBackend(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        setup_class(cls, backend="outlines", disable_overlap=True)
+        cls.check_jump_forward = True
+
 class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
-        cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.json_schema = json.dumps(
-            {
-                "type": "object",
-                "properties": {
-                    "name": {"type": "string"},
-                    "population": {"type": "integer"},
-                },
-                "required": ["name", "population"],
-            }
-        )
-        cls.process = popen_launch_server(
-            cls.model,
-            cls.base_url,
-            timeout=300,
-            other_args=[
-                "--max-running-requests",
-                "10",
-                "--grammar-backend",
-                "xgrammar",
-            ],
-        )
+        setup_class(cls, backend="xgrammar", disable_overlap=False)
+        cls.check_jump_forward = False
 
 if __name__ == "__main__":
```
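To exercise the fixed path end to end, one can launch a server with the xgrammar backend (overlap scheduling is the default) and send a JSON-constrained request, roughly what `run_decode` in this test does. A hedged example against the native `/generate` endpoint; the host, port, prompt, and sampling parameters are illustrative, and field names should be checked against your sglang version:

```python
import json
import requests

# Assumes a server started with something like:
#   python3 -m sglang.launch_server --model-path <model> --grammar-backend xgrammar
json_schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
    }
)

response = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "Here is the information of the capital of France in the JSON format.\n",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 64,
            "json_schema": json_schema,
        },
    },
)
ret = response.json()
print(json.loads(ret["text"]))  # should parse as an object with name/population
```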