sglang commit 9116b289 (unverified)

Add a new event loop (#1677)

Authored Oct 16, 2024 by Lianmin Zheng; committed by GitHub on Oct 16, 2024.
Parent: a5114b6f

Showing 9 changed files, with 161 additions and 25 deletions (+161, -25):
  python/sglang/srt/managers/schedule_batch.py   +4   -0
  python/sglang/srt/managers/scheduler.py        +49  -14
  python/sglang/srt/mem_cache/chunk_cache.py     +11  -5
  python/sglang/srt/mem_cache/radix_cache.py     +22  -6
  python/sglang/srt/server.py                    +2   -0
  python/sglang/srt/server_args.py               +6   -0
  python/sglang/srt/utils.py                     +1   -0
  test/srt/run_suite.py                          +1   -0
  test/srt/test_overlap_schedule.py              +65  -0
python/sglang/srt/managers/schedule_batch.py

@@ -736,6 +736,10 @@ class ScheduleBatch:
         self.input_ids = self.output_ids
         self.seq_lens.add_(1)
         self.output_ids = None
+        if self.sampling_info.penalizer_orchestrator:
+            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                self.input_ids
+            )

         # Alloc mem
         bs = len(self.reqs)
python/sglang/srt/managers/scheduler.py

@@ -20,6 +20,7 @@ import logging
 import os
 import time
 import warnings
+from collections import deque
 from types import SimpleNamespace
 from typing import List, Optional, Union
@@ -192,9 +193,20 @@ class Scheduler:
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

+        if self.server_args.enable_overlap_schedule:
+
+            def cache_finished_req(req):
+                free_delta = int(self.running_batch and req in self.cur_batch.reqs)
+                self.tree_cache.cache_finished_req(req, free_delta=free_delta)
+
+        else:
+            cache_finished_req = self.tree_cache.cache_finished_req
+
+        self.cache_finished_req = cache_finished_req
+
         # Init running status
         self.waiting_queue: List[Req] = []
         self.running_batch: Optional[ScheduleBatch] = None
+        self.cur_batch: Optional[ScheduleBatch] = None
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
         self.num_generated_tokens = 0
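A hedged aside on the hunk above: binding cache_finished_req once in the constructor keeps the per-request call sites branch-free, and the free_delta term accounts for the extra KV slot a finishing request may still hold while it sits in the batch currently running on the GPU. A minimal stand-alone sketch of this dispatch pattern (the helper name and arguments below are illustrative stand-ins, not sglang APIs):

def bind_cache_finished_req(enable_overlap, tree_cache, cur_batch_reqs):
    # Choose the implementation once; callers never re-check the flag.
    if enable_overlap:
        def cache_finished_req(req):
            # A request that is in the batch currently on the GPU already
            # owns one extra allocated KV slot; ask the cache to free it too.
            tree_cache.cache_finished_req(
                req, free_delta=int(req in cur_batch_reqs)
            )
        return cache_finished_req
    # Without overlap, delegate directly to the tree cache method.
    return tree_cache.cache_finished_req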
@@ -279,6 +291,32 @@ class Scheduler:

             self.last_batch = batch

+    @torch.inference_mode()
+    def event_loop_overlap(self):
+        result_queue = deque()
+
+        self.last_batch = None
+        self.running_batch = None
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+
+            batch = self.get_next_batch_to_run()
+            self.cur_batch = batch
+            if batch:
+                result = self.run_batch(batch)
+                result_queue.append((batch.copy(), result))
+
+            if self.last_batch:
+                tmp_batch, tmp_result = result_queue.popleft()
+                self.process_batch_result(tmp_batch, tmp_result)
+            elif batch is None:
+                self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio
+
+            self.last_batch = batch
+
     def recv_requests(self):
         if self.tp_rank == 0:
             recv_reqs = []
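The loop above is the heart of this commit: batch results are consumed one iteration late, so CPU-side post-processing of step N overlaps with the GPU work of step N+1. A hedged, self-contained sketch of the pipelining idea, with hypothetical get_next_batch/run_batch/process_result callables (the real loop also handles request intake, idle memory checks, and the token-ratio reset):

from collections import deque

def overlap_loop(get_next_batch, run_batch, process_result):
    result_queue = deque()
    last_batch = None
    while True:
        batch = get_next_batch()
        if batch:
            # Launch GPU work for the current step first; run_batch is
            # assumed to return as soon as the kernels are enqueued.
            result_queue.append((batch, run_batch(batch)))
        if last_batch:
            # Then do the previous step's CPU post-processing while the
            # GPU is still busy with `batch`; this deferred consumption
            # is what creates the overlap.
            process_result(*result_queue.popleft())
        last_batch = batch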
@@ -705,11 +743,6 @@ class Scheduler:

     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
             logits_output, next_token_ids = result
-            if batch.sampling_info.penalizer_orchestrator:
-                batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                    next_token_ids
-                )

             if batch.return_logprob:
                 # Move logprobs to cpu
                 if logits_output.next_token_logprobs is not None:
@@ -742,7 +775,7 @@ class Scheduler:

                 req.check_finished()
                 if req.finished():
-                    self.tree_cache.cache_finished_req(req)
+                    self.cache_finished_req(req)
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     self.tree_cache.cache_unfinished_req(req)
@@ -771,7 +804,7 @@ class Scheduler:

                 req.check_finished()
                 if req.finished():
-                    self.tree_cache.cache_finished_req(req)
+                    self.cache_finished_req(req)
                 else:
                     self.tree_cache.cache_unfinished_req(req)
@@ -779,10 +812,6 @@ class Scheduler:

     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
-        if batch.sampling_info.penalizer_orchestrator:
-            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                next_token_ids
-            )
         self.num_generated_tokens += len(batch.reqs)

         # Move logprobs to cpu
@@ -796,6 +825,9 @@ class Scheduler:

         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+            if self.server_args.enable_overlap_schedule and req.finished():
+                continue
+
             req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_token_id)
             req.check_finished()
@@ -806,7 +838,7 @@ class Scheduler:

             )
             if req.finished():
-                self.tree_cache.cache_finished_req(req)
+                self.cache_finished_req(req)

             if req.return_logprob:
                 req.output_token_logprobs.append(
@@ -1027,7 +1059,7 @@ class Scheduler:

             for req in self.running_batch.reqs:
                 if req.rid == recv_req.rid and not req.finished():
                     req.finished_reason = FINISH_ABORT()
-                    self.tree_cache.cache_finished_req(req)
+                    self.cache_finished_req(req)
                     break

     def update_weights(self, recv_req: UpdateWeightReqInput):
@@ -1072,7 +1104,10 @@ def run_scheduler_process(
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
         pipe_writer.send("ready")
-        scheduler.event_loop_normal()
+        if server_args.enable_overlap_schedule:
+            scheduler.event_loop_overlap()
+        else:
+            scheduler.event_loop_normal()
     except Exception:
         msg = get_exception_traceback()
         logger.error(msg)
python/sglang/srt/mem_cache/chunk_cache.py

@@ -38,12 +38,16 @@ class ChunkCache(BasePrefixCache):
         max_prefix_len = len(key)
         return entry.value[:max_prefix_len], entry

-    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
+    def cache_finished_req(
+        self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
+    ):
         if token_ids is None:
-            token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+            token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
+        else:
+            token_id_len = len(token_ids)
+
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : len(token_ids)
+            req.req_pool_idx, : token_id_len + free_delta
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool.free(kv_indices)

@@ -53,10 +57,12 @@ class ChunkCache(BasePrefixCache):
     def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
-            token_ids = req.fill_ids
+            token_id_len = len(req.fill_ids)
+        else:
+            token_id_len = len(token_ids)
+
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : len(token_ids)
+            req.req_pool_idx, :token_id_len
         ]

         if req.rid not in self.entries:
python/sglang/srt/mem_cache/radix_cache.py

@@ -97,22 +97,38 @@ class RadixCache(BasePrefixCache):
         value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)

-    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
+    def cache_finished_req(
+        self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
+    ):
         """Cache request when it finishes."""
+        if self.disable:
+            if token_ids is None:
+                token_ids_len = len(req.origin_input_ids) + len(req.output_ids) - 1
+            else:
+                token_ids_len = len(token_ids)
+
+            kv_indices = self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, : token_ids_len + free_delta
+            ]
+            self.token_to_kv_pool.free(kv_indices)
+            self.req_to_token_pool.free(req.req_pool_idx)
+            return
+
         if token_ids is None:
             token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(token_ids)
         ]

-        if self.disable:
-            self.token_to_kv_pool.free(kv_indices)
-            self.req_to_token_pool.free(req.req_pool_idx)
-            return
-
         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(token_ids, kv_indices.clone())
         self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])

+        if free_delta:
+            self.token_to_kv_pool.free(
+                self.req_to_token_pool.req_to_token[
+                    req.req_pool_idx, len(token_ids) : len(token_ids) + 1
+                ]
+            )
+
         # Remove req slot release the cache lock
         self.req_to_token_pool.free(req.req_pool_idx)
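Why free_delta exists: under the overlap schedule a request can be declared finished while the in-flight decode step has already claimed one extra slot in its req_to_token row, so freeing exactly len(token_ids) slots would leak that entry. A small illustration of the slicing arithmetic, with hypothetical row contents, mirroring the req_to_token[req.req_pool_idx, : token_ids_len + free_delta] expression above:

import torch

req_to_token_row = torch.arange(16)  # 16 pre-allocated slots for one request
token_ids_len, free_delta = 12, 1    # 12 cached tokens + 1 in-flight slot

kv_indices = req_to_token_row[: token_ids_len + free_delta]
assert kv_indices.numel() == 13      # the extra in-flight slot is reclaimed too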
python/sglang/srt/server.py

@@ -528,6 +528,8 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
             kill_child_process(pid, including_parent=False)
             return

+    # print(f"{res.json()=}")
+
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("ready")
python/sglang/srt/server_args.py

@@ -113,6 +113,7 @@ class ServerArgs:
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
     disable_penalizer: bool = False
+    enable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     max_torch_compile_bs: int = 32

@@ -572,6 +573,11 @@ class ServerArgs:
             action="store_true",
             help="Disable the logit penalizer (e.g., frequency and repetition penalty).",
         )
+        parser.add_argument(
+            "--enable-overlap-schedule",
+            action="store_true",
+            help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
+        )
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",
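For reference, a hedged usage sketch: the new option is a plain store_true flag, so it can be passed on the launch command line (--enable-overlap-schedule) or set on a ServerArgs instance. Only enable_overlap_schedule comes from this diff; the model_path argument and model name below are assumptions about the surrounding API:

from sglang.srt.server_args import ServerArgs

# All other fields keep their dataclass defaults,
# including enable_overlap_schedule=False.
args = ServerArgs(model_path="meta-llama/Meta-Llama-3-8B-Instruct")
args.enable_overlap_schedule = True  # opt in to the experimental loop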
python/sglang/srt/utils.py

@@ -584,6 +584,7 @@ def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):

 def configure_logger(server_args, prefix: str = ""):
     format = f"[%(asctime)s {prefix}] %(message)s"
+    # format = f"[%(asctime)s.%(msecs)03d{prefix}] %(message)s"
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format=format,
test/srt/run_suite.py

@@ -17,6 +17,7 @@ suites = {
         "test_json_constrained.py",
         "test_large_max_new_tokens.py",
         "test_openai_server.py",
+        "test_overlap_schedule.py",
         "test_pytorch_sampling_backend.py",
         "test_retract_decode.py",
         "test_server_args.py",
test/srt/test_overlap_schedule.py  (new file, mode 100644)

"""
Usage:
SGLANG_IS_IN_CI=true python3 -m unittest test_overlap_schedule.TestOverlapSchedule.test_radix_attention_chunked_prefill
SGLANG_IS_IN_CI=true python3 test_overlap_schedule.py
"""

import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestOverlapSchedule(unittest.TestCase):
    def run_mmlu(self, disable_radix_cache, chunked_prefill_size=32):
        other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
        if disable_radix_cache:
            other_args += ["--disable-radix-cache"]
        other_args += ["--enable-overlap-schedule"]

        model = DEFAULT_MODEL_NAME_FOR_TEST
        base_url = DEFAULT_URL_FOR_TEST
        process = popen_launch_server(
            model,
            base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )

        args = SimpleNamespace(
            base_url=base_url,
            model=model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        try:
            metrics = run_eval(args)
            assert metrics["score"] >= 0.65
        finally:
            kill_child_process(process.pid)

    def test_no_radix_attention_chunked_prefill(self):
        self.run_mmlu(disable_radix_cache=True, chunked_prefill_size=32)

    def test_no_radix_attention_no_chunked_prefill(self):
        self.run_mmlu(disable_radix_cache=True, chunked_prefill_size=-1)

    def test_radix_attention_chunked_prefill(self):
        self.run_mmlu(disable_radix_cache=False, chunked_prefill_size=32)

    def test_radix_attention_no_chunked_prefill(self):
        self.run_mmlu(disable_radix_cache=False, chunked_prefill_size=-1)


if __name__ == "__main__":
    unittest.main()

# @unittest.skip("did not support")