change / sglang · Commits · 58d1082e

Unverified commit 58d1082e, authored Oct 06, 2024 by Lianmin Zheng, committed by GitHub on Oct 06, 2024

Clean up event loop (#1586)
parent 4d086719

Showing 1 changed file with 220 additions and 205 deletions:

    python/sglang/srt/managers/scheduler.py  (+220, -205)
python/sglang/srt/managers/scheduler.py @ 58d1082e
@@ -228,20 +228,14 @@ class Scheduler:
         self.new_token_ratio_decay = global_config.new_token_ratio_decay

         self.batch_is_full = False

     @torch.inference_mode()
     def event_loop(self):
         while True:
             # Receive requests
-            if self.tp_rank == 0:
-                recv_reqs = self.recv_requests_from_zmq()
-            else:
-                recv_reqs = None
-
-            # Process requests
-            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
-            self.process_requests(recv_reqs)
-
-            # Forward
-            self.forward_step()
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+
+            # Run one step
+            self.run_step()

             # Send results
             if self.tp_rank == 0:
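For reference, the rewritten loop body above reduces each iteration to three calls: receive, ingest, run one step. A minimal, self-contained sketch of that shape (not part of the commit; MiniScheduler and its stub methods are invented for illustration):

# Illustrative sketch only: the per-iteration flow of the new event loop.
class MiniScheduler:
    def __init__(self):
        self.inbox = ["req-1", "req-2"]

    def recv_requests(self):
        # In the real code this drains a ZMQ socket on tp_rank 0 and broadcasts
        # the result to the other tensor-parallel ranks.
        reqs, self.inbox = self.inbox, []
        return reqs

    def process_input_requests(self, reqs):
        for r in reqs:
            print("queued", r)

    def run_step(self):
        # In the real code this forms a prefill batch if possible,
        # otherwise runs decode steps on the running batch.
        print("ran one scheduling step")

    def event_loop_once(self):
        recv_reqs = self.recv_requests()
        self.process_input_requests(recv_reqs)
        self.run_step()


if __name__ == "__main__":
    MiniScheduler().event_loop_once()

The real methods behind these three calls are defined in the hunks below.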
@@ -249,19 +243,23 @@ class Scheduler:
                 self.send_to_detokenizer.send_pyobj(obj)
             self.out_pyobjs = []

-    def recv_requests_from_zmq(self):
-        recv_reqs = []
-
-        while True:
-            try:
-                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
-            except zmq.ZMQError:
-                break
-            recv_reqs.append(recv_req)
+    def recv_requests(self):
+        if self.tp_rank == 0:
+            recv_reqs = []
+
+            while True:
+                try:
+                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                except zmq.ZMQError:
+                    break
+                recv_reqs.append(recv_req)
+        else:
+            recv_reqs = None
+
+        recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
         return recv_reqs

-    def process_requests(self, recv_reqs: List):
+    def process_input_requests(self, recv_reqs: List):
         for recv_req in recv_reqs:
             if isinstance(recv_req, TokenizedGenerateReqInput):
                 self.handle_generate_request(recv_req)
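For reference, recv_requests above keeps the non-blocking drain pattern and now also owns the broadcast to non-zero tensor-parallel ranks. A self-contained sketch of just the drain loop (not part of the commit; the inproc endpoint and payloads are made up, and the sglang-specific broadcast_pyobj call is omitted):

# Illustrative sketch only: drain a ZMQ socket until it is empty, using NOBLOCK.
import time
import zmq

ctx = zmq.Context.instance()
pull = ctx.socket(zmq.PULL)
pull.bind("inproc://reqs")          # stand-in for the tokenizer -> scheduler channel
push = ctx.socket(zmq.PUSH)
push.connect("inproc://reqs")

for i in range(3):
    push.send_pyobj({"rid": i})     # pretend tokenized requests
time.sleep(0.1)                     # give the messages a moment to land

recv_reqs = []
while True:
    try:
        # Non-blocking receive; raises zmq.ZMQError once the queue is drained.
        recv_reqs.append(pull.recv_pyobj(zmq.NOBLOCK))
    except zmq.ZMQError:
        break

print(recv_reqs)                    # [{'rid': 0}, {'rid': 1}, {'rid': 2}]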
@@ -279,83 +277,6 @@ class Scheduler:
             else:
                 raise ValueError(f"Invalid request: {recv_req}")

-    @torch.inference_mode()
-    def forward_step(self):
-        if (
-            self.batch_is_full or len(self.waiting_queue) == 0
-        ) and self.current_inflight_req is None:
-            new_batch = None
-        else:
-            new_batch = self.get_new_prefill_batch()
-
-        if new_batch is not None:
-            # Run a new prefill batch
-            self.forward_prefill_batch(new_batch)
-
-            if not new_batch.is_empty():
-                if self.running_batch is None:
-                    self.running_batch = new_batch
-                else:
-                    self.running_batch.merge_batch(new_batch)
-        else:
-            # Run a decode batch
-            if self.running_batch is not None:
-                # Run a few decode batches continuously for reducing overhead
-                for _ in range(global_config.num_continue_decode_steps):
-                    self.num_generated_tokens += len(self.running_batch.reqs)
-                    self.forward_decode_batch(self.running_batch)
-
-                    # Print stats
-                    if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
-                        self.print_decode_stats()
-
-                    if self.running_batch.is_empty():
-                        self.running_batch = None
-                        break
-
-                    if self.out_pyobjs and self.running_batch.has_stream:
-                        break
-            else:
-                self.check_memory()
-                self.new_token_ratio = global_config.init_new_token_ratio
-
-    def print_decode_stats(self):
-        num_used = self.max_total_num_tokens - (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
-        )
-        throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
-        self.num_generated_tokens = 0
-        self.last_stats_tic = time.time()
-        logger.info(
-            f"Decode batch. "
-            f"#running-req: {len(self.running_batch.reqs)}, "
-            f"#token: {num_used}, "
-            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-            f"gen throughput (token/s): {throughput:.2f}, "
-            f"#queue-req: {len(self.waiting_queue)}"
-        )
-
-    def check_memory(self):
-        available_size = (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
-        )
-        if available_size != self.max_total_num_tokens:
-            warnings.warn(
-                "Warning: "
-                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
-                "KV cache pool leak detected!"
-            )
-            exit(1) if crash_on_warning else None
-
-        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
-            warnings.warn(
-                "Warning: "
-                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
-                f"total slots={self.req_to_token_pool.size}\n"
-                "Memory pool leak detected!"
-            )
-            exit(1) if crash_on_warning else None
-
     def handle_generate_request(
         self,
         recv_req: TokenizedGenerateReqInput,
@@ -445,7 +366,88 @@ class Scheduler:
         self.waiting_queue.append(req)

-    def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
+    def run_step(self):
+        new_batch = self.get_new_batch_prefill()
+        if new_batch is not None:
+            # Run a new prefill batch
+            result = self.run_batch(new_batch)
+            self.process_batch_result(new_batch, result)
+
+            if not new_batch.is_empty():
+                if self.running_batch is None:
+                    self.running_batch = new_batch
+                else:
+                    self.running_batch.merge_batch(new_batch)
+        else:
+            # Run a decode batch
+            if self.running_batch is not None:
+                # Run a few decode batches continuously for reducing overhead
+                for _ in range(global_config.num_continue_decode_steps):
+                    batch = self.get_new_batch_decode()
+
+                    if batch:
+                        result = self.run_batch(batch)
+                        self.process_batch_result(batch, result)
+
+                    # Print stats
+                    if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
+                        self.print_decode_stats()
+
+                    if self.running_batch.is_empty():
+                        self.running_batch = None
+                        break
+
+                    if self.out_pyobjs and self.running_batch.has_stream:
+                        break
+            else:
+                self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio
+
+    def print_decode_stats(self):
+        num_used = self.max_total_num_tokens - (
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        )
+        throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
+        self.num_generated_tokens = 0
+        self.last_stats_tic = time.time()
+        logger.info(
+            f"Decode batch. "
+            f"#running-req: {len(self.running_batch.reqs)}, "
+            f"#token: {num_used}, "
+            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+            f"gen throughput (token/s): {throughput:.2f}, "
+            f"#queue-req: {len(self.waiting_queue)}"
+        )
+
+    def check_memory(self):
+        available_size = (
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        )
+        if available_size != self.max_total_num_tokens:
+            warnings.warn(
+                "Warning: "
+                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
+                "KV cache pool leak detected!"
+            )
+            exit(1) if crash_on_warning else None
+
+        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
+            warnings.warn(
+                "Warning: "
+                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
+                f"total slots={self.req_to_token_pool.size}\n"
+                "Memory pool leak detected!"
+            )
+            exit(1) if crash_on_warning else None
+
+    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
+        # Handle the cases where prefill is not allowed
+        if (
+            self.batch_is_full or len(self.waiting_queue) == 0
+        ) and self.current_inflight_req is None:
+            return None
+
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
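For reference, check_memory (moved here as part of the new run_step path) asserts that, between steps, every KV-cache token is either free in token_to_kv_pool or evictable from the tree cache. A tiny sketch of that invariant with made-up numbers (not part of the commit):

# Illustrative sketch only: the accounting invariant behind the leak check.
max_total_num_tokens = 1000      # total KV-cache capacity (made-up number)
kv_pool_available = 940          # free slots, as reported by token_to_kv_pool.available_size()
tree_cache_evictable = 60        # cached-but-evictable tokens held by the radix cache

available_size = kv_pool_available + tree_cache_evictable
if available_size != max_total_num_tokens:
    raise RuntimeError("KV cache pool leak detected!")
print("no leak: every token is either free or evictable")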
@@ -456,8 +458,8 @@ class Scheduler:
         # Get priority queue
         prefix_computed = self.policy.calc_priority(self.waiting_queue)

         # Prefill policy
         num_mixed_running = running_bs if self.is_mixed_chunk else 0
         adder = PrefillAdder(
             self.tree_cache,
             self.running_batch,
@@ -517,6 +519,8 @@ class Scheduler:
         if len(can_run_list) == 0:
             return None

+        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
+
         # Print stats
         if self.tp_rank == 0:
             if isinstance(self.tree_cache, RadixCache):
@@ -544,7 +548,7 @@ class Scheduler:
                     f"#cached-token: {adder.log_hit_tokens}, "
                     f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                     f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                    f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
                 )
             else:
                 logger.info(
@@ -555,41 +559,97 @@ class Scheduler:
                     f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                     f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                     f"#running-req: {running_bs}, "
-                    f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
                 )

-        # Return the new batch
+        # Create a new batch
         new_batch = ScheduleBatch.init_new(
             can_run_list,
             self.req_to_token_pool,
             self.token_to_kv_pool,
             self.tree_cache,
         )
-        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
-        return new_batch
-
-    def forward_prefill_batch(self, batch: ScheduleBatch):
-        # Build batch tensors
-        batch.prepare_for_extend(self.model_config.vocab_size)
+        new_batch.prepare_for_extend(self.model_config.vocab_size)

         # Mixed-style chunked prefill
         decoding_reqs = []
         if self.is_mixed_chunk and self.running_batch is not None:
             self.running_batch.prepare_for_decode()
-            batch.mix_with_running(self.running_batch)
+            new_batch.mix_with_running(self.running_batch)
             decoding_reqs = self.running_batch.reqs
             self.running_batch = None
+        new_batch.decoding_reqs = decoding_reqs
+
+        return new_batch
+
+    def get_new_batch_decode(self) -> Optional[ScheduleBatch]:
+        batch = self.running_batch
+
+        # Check if decode out of memory
+        if not batch.check_decode_mem():
+            old_ratio = self.new_token_ratio
+
+            retracted_reqs, new_token_ratio = batch.retract_decode()
+            self.new_token_ratio = new_token_ratio
+
+            logger.info(
+                "Decode out of memory happened. "
+                f"#retracted_reqs: {len(retracted_reqs)}, "
+                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
+            )
+            self.waiting_queue.extend(retracted_reqs)
+        else:
+            self.new_token_ratio = max(
+                self.new_token_ratio - self.new_token_ratio_decay,
+                self.min_new_token_ratio,
+            )
+
+        # Check for jump-forward
+        if not self.disable_regex_jump_forward:
+            jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
+            self.waiting_queue.extend(jump_forward_reqs)
+            if batch.is_empty():
+                return None
+
+        # Update batch tensors
+        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
+        batch.prepare_for_decode()
+        return batch
+
+    def run_batch(self, batch: ScheduleBatch):
         if self.is_generation:
             # Forward and sample the next tokens
-            if batch.extend_num_tokens != 0:
+            if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()
                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                     model_worker_batch
                 )
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids
                 )
             else:
                 logits_output = None
                 if self.tokenizer is not None:
                     next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
                 else:
                     next_token_ids = [0] * len(batch.reqs)
+            return logits_output, next_token_ids
+        else:  # embedding or reward model
+            assert batch.extend_num_tokens != 0
+            model_worker_batch = batch.get_model_worker_batch()
+            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
+            return embeddings
+
+    def process_batch_result(self, batch: ScheduleBatch, result):
+        if batch.forward_mode.is_decode():
+            self.process_batch_result_decode(batch, result)
+        else:
+            self.process_batch_result_prefill(batch, result)
+
+    def process_batch_result_prefill(self, batch: ScheduleBatch, result):
+        if self.is_generation:
+            logits_output, next_token_ids = result
+            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                next_token_ids
+            )

             if logits_output:
                 # Move logprobs to cpu
                 if logits_output.next_token_logprobs is not None:
                     logits_output.next_token_logprobs = (
@@ -607,16 +667,7 @@ class Scheduler:
                         logits_output.normalized_prompt_logprobs.tolist()
                     )

-                next_token_ids = next_token_ids.tolist()
-            else:
-                if self.tokenizer is None:
-                    next_token_ids = []
-                    for req in batch.reqs:
-                        next_token_ids.append(
-                            next(iter(req.sampling_params.stop_token_ids))
-                        )
-                else:
-                    next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+            next_token_ids = next_token_ids.tolist()

             # Check finish conditions
             logprob_pt = 0
@@ -634,7 +685,7 @@ class Scheduler:
                 if req.finished():
                     self.tree_cache.cache_finished_req(req)
-                elif req not in decoding_reqs:
+                elif req not in batch.decoding_reqs:
                     # To reduce overhead, only cache prefill reqs
                     self.tree_cache.cache_unfinished_req(req)
@@ -646,10 +697,9 @@ class Scheduler:
                     logprob_pt += self.add_logprob_return_values(
                         i, req, logprob_pt, next_token_ids, logits_output
                     )
-        else:
-            assert batch.extend_num_tokens != 0
-            model_worker_batch = batch.get_model_worker_batch()
-            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
+        else:  # embedding or reward model
+            embeddings = result

             # Check finish conditions
             for i, req in enumerate(batch.reqs):
@@ -671,6 +721,45 @@ class Scheduler:
         self.handle_finished_requests(batch)

+    def process_batch_result_decode(self, batch: ScheduleBatch, result):
+        logits_output, next_token_ids = result
+        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+            next_token_ids
+        )
+        self.num_generated_tokens += len(batch.reqs)
+
+        # Move logprobs to cpu
+        if logits_output.next_token_logprobs is not None:
+            next_token_logprobs = logits_output.next_token_logprobs[
+                torch.arange(len(next_token_ids), device=next_token_ids.device),
+                next_token_ids,
+            ].tolist()
+
+        next_token_ids = next_token_ids.tolist()
+
+        # Check finish condition
+        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+            req.completion_tokens_wo_jump_forward += 1
+            req.output_ids.append(next_token_id)
+            req.check_finished()
+
+            if req.regex_fsm is not None:
+                req.regex_fsm_state = req.regex_fsm.get_next_state(
+                    req.regex_fsm_state, next_token_id
+                )
+
+            if req.finished():
+                self.tree_cache.cache_finished_req(req)
+
+            if req.return_logprob:
+                req.output_token_logprobs.append(
+                    (next_token_logprobs[i], next_token_id)
+                )
+                if req.top_logprobs_num > 0:
+                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
+
+        self.handle_finished_requests(batch)
+
     def add_logprob_return_values(
         self,
         i: int,
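For reference, the new process_batch_result_decode above gathers one log-probability per request via advanced indexing before moving values to the CPU. A standalone sketch of that gather (not part of the commit; shapes and values are made up):

# Illustrative sketch only: pick each request's sampled-token logprob from a
# [batch, vocab] tensor with advanced indexing.
import torch

batch_size, vocab_size = 4, 32
# Pretend model output: one log-probability row per request in the decode batch.
next_token_logprobs_full = torch.randn(batch_size, vocab_size).log_softmax(dim=-1)
next_token_ids = torch.randint(vocab_size, (batch_size,))

# Row i is paired with column next_token_ids[i]: the logprob of the token
# request i actually sampled.
picked = next_token_logprobs_full[
    torch.arange(batch_size, device=next_token_ids.device),
    next_token_ids,
].tolist()

print(len(picked))  # 4, one float per request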
@@ -744,80 +833,6 @@ class Scheduler:
         return num_input_logprobs

-    def forward_decode_batch(self, batch: ScheduleBatch):
-        # Check if decode out of memory
-        if not batch.check_decode_mem():
-            old_ratio = self.new_token_ratio
-
-            retracted_reqs, new_token_ratio = batch.retract_decode()
-            self.new_token_ratio = new_token_ratio
-
-            logger.info(
-                "Decode out of memory happened. "
-                f"#retracted_reqs: {len(retracted_reqs)}, "
-                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
-            )
-            self.waiting_queue.extend(retracted_reqs)
-        else:
-            self.new_token_ratio = max(
-                self.new_token_ratio - self.new_token_ratio_decay,
-                self.min_new_token_ratio,
-            )
-
-        # Check for jump-forward
-        if not self.disable_regex_jump_forward:
-            jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
-            self.waiting_queue.extend(jump_forward_reqs)
-            if batch.is_empty():
-                return
-
-        # Update batch tensors
-        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
-        batch.prepare_for_decode()
-
-        # Forward and sample the next tokens
-        model_worker_batch = batch.get_model_worker_batch()
-        logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-            model_worker_batch
-        )
-        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-            next_token_ids
-        )
-
-        # Move logprobs to cpu
-        if logits_output.next_token_logprobs is not None:
-            next_token_logprobs = logits_output.next_token_logprobs[
-                torch.arange(len(next_token_ids), device=next_token_ids.device),
-                next_token_ids,
-            ].tolist()
-
-        next_token_ids = next_token_ids.tolist()
-
-        # Check finish condition
-        has_finished = False
-        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
-            req.completion_tokens_wo_jump_forward += 1
-            req.output_ids.append(next_token_id)
-            req.check_finished()
-
-            if req.regex_fsm is not None:
-                req.regex_fsm_state = req.regex_fsm.get_next_state(
-                    req.regex_fsm_state, next_token_id
-                )
-
-            if req.finished():
-                self.tree_cache.cache_finished_req(req)
-                has_finished = True
-
-            if req.return_logprob:
-                req.output_token_logprobs.append(
-                    (next_token_logprobs[i], next_token_id)
-                )
-                if req.top_logprobs_num > 0:
-                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
-
-        self.handle_finished_requests(batch)
-
     def handle_finished_requests(self, batch: ScheduleBatch):
         output_rids = []
         output_meta_info = []
@@ -829,7 +844,7 @@ class Scheduler:
             output_read_offsets = []
             output_skip_special_tokens = []
             output_spaces_between_special_tokens = []
-        else:  # for embedding model
+        else:  # embedding or reward model
             output_embeddings = []

         unfinished_indices = []
@@ -886,7 +901,7 @@ class Scheduler:
                         req.normalized_prompt_logprob,
                     )
                 output_meta_info.append(meta_info)
-            else:  # for embedding model
+            else:  # embedding or reward model
                 output_embeddings.append(req.embedding)
                 meta_info = {
                     "prompt_tokens": len(req.origin_input_ids),
@@ -909,7 +924,7 @@ class Scheduler:
                     output_finished_reason,
                 )
             )
-        else:  # for embedding model
+        else:  # embedding or reward model
             self.out_pyobjs.append(
                 BatchEmbeddingOut(
                     output_rids,