zhaoyu6 / sglang · Commits · 58d1082e

Commit 58d1082e (unverified)
Authored Oct 06, 2024 by Lianmin Zheng; committed via GitHub on Oct 06, 2024
Parent: 4d086719

    Clean up event loop (#1586)

Showing 1 changed file with 220 additions and 205 deletions:
python/sglang/srt/managers/scheduler.py (+220, -205)
@@ -228,20 +228,14 @@ class Scheduler:
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
         self.batch_is_full = False
 
+    @torch.inference_mode()
     def event_loop(self):
         while True:
-            # Receive requests
-            if self.tp_rank == 0:
-                recv_reqs = self.recv_requests_from_zmq()
-            else:
-                recv_reqs = None
-            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
-
-            # Process requests
-            self.process_requests(recv_reqs)
-
-            # Forward
-            self.forward_step()
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+
+            # Run one step
+            self.run_step()
 
             # Send results
             if self.tp_rank == 0:
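This hunk is the core of the cleanup: the old event_loop inlined the TP-rank-0 special casing, the broadcast, and a monolithic forward_step, while the new one delegates to recv_requests, process_input_requests, and run_step. Below is a minimal, self-contained sketch of that loop shape only; the class, queue, and everything not visible in the diff are toy stand-ins, not sglang's actual code.

# Toy sketch: receive -> process -> step structure of the new event_loop.
import queue


class ToyScheduler:
    def __init__(self):
        self.inbox = queue.Queue()
        self.outbox = []

    def recv_requests(self):
        reqs = []
        while True:
            try:
                reqs.append(self.inbox.get_nowait())  # non-blocking drain
            except queue.Empty:
                break
        return reqs

    def process_input_requests(self, reqs):
        self.outbox.extend(f"queued:{r}" for r in reqs)

    def run_step(self):
        pass  # prefill/decode work would happen here

    def event_loop_once(self):
        reqs = self.recv_requests()
        self.process_input_requests(reqs)
        self.run_step()
        return self.outbox


sched = ToyScheduler()
sched.inbox.put("req-0")
print(sched.event_loop_once())  # ['queued:req-0']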
@@ -249,19 +243,23 @@ class Scheduler:
                 self.send_to_detokenizer.send_pyobj(obj)
                 self.out_pyobjs = []
 
-    def recv_requests_from_zmq(self):
-        recv_reqs = []
+    def recv_requests(self):
+        if self.tp_rank == 0:
+            recv_reqs = []
 
-        while True:
-            try:
-                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
-            except zmq.ZMQError:
-                break
-            recv_reqs.append(recv_req)
+            while True:
+                try:
+                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                except zmq.ZMQError:
+                    break
+                recv_reqs.append(recv_req)
+        else:
+            recv_reqs = None
 
+        recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
         return recv_reqs
 
-    def process_requests(self, recv_reqs: List):
+    def process_input_requests(self, recv_reqs: List):
         for recv_req in recv_reqs:
             if isinstance(recv_req, TokenizedGenerateReqInput):
                 self.handle_generate_request(recv_req)
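recv_requests now folds the rank-0 ZMQ drain and the broadcast_pyobj call into one method. The non-blocking drain (recv until zmq.ZMQError) can be tried in isolation with pyzmq, as in the hedged sketch below; the endpoint, socket types, and the broadcast stand-in are illustrative only, not sglang's actual setup.

# Sketch of the rank-0 drain pattern, assuming pyzmq is installed.
import zmq

ctx = zmq.Context.instance()
recv_from_tokenizer = ctx.socket(zmq.PULL)
recv_from_tokenizer.bind("inproc://toy-tokenizer")

sender = ctx.socket(zmq.PUSH)
sender.connect("inproc://toy-tokenizer")
sender.send_pyobj({"rid": 0, "text": "hello"})


def drain_nonblocking(sock):
    reqs = []
    while True:
        try:
            reqs.append(sock.recv_pyobj(zmq.NOBLOCK))  # raises ZMQError when empty
        except zmq.ZMQError:
            break
    return reqs


def broadcast_stub(obj, rank=0):
    # stand-in for broadcast_pyobj(obj, tp_rank, tp_cpu_group) on a one-rank "group"
    return obj


print(broadcast_stub(drain_nonblocking(recv_from_tokenizer)))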
@@ -279,83 +277,6 @@ class Scheduler:
             else:
                 raise ValueError(f"Invalid request: {recv_req}")
 
-    @torch.inference_mode()
-    def forward_step(self):
-        if (
-            self.batch_is_full or len(self.waiting_queue) == 0
-        ) and self.current_inflight_req is None:
-            new_batch = None
-        else:
-            new_batch = self.get_new_prefill_batch()
-
-        if new_batch is not None:
-            # Run a new prefill batch
-            self.forward_prefill_batch(new_batch)
-
-            if not new_batch.is_empty():
-                if self.running_batch is None:
-                    self.running_batch = new_batch
-                else:
-                    self.running_batch.merge_batch(new_batch)
-        else:
-            # Run a decode batch
-            if self.running_batch is not None:
-                # Run a few decode batches continuously for reducing overhead
-                for _ in range(global_config.num_continue_decode_steps):
-                    self.num_generated_tokens += len(self.running_batch.reqs)
-                    self.forward_decode_batch(self.running_batch)
-
-                    # Print stats
-                    if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
-                        self.print_decode_stats()
-
-                    if self.running_batch.is_empty():
-                        self.running_batch = None
-                        break
-
-                    if self.out_pyobjs and self.running_batch.has_stream:
-                        break
-            else:
-                self.check_memory()
-                self.new_token_ratio = global_config.init_new_token_ratio
-
-    def print_decode_stats(self):
-        num_used = self.max_total_num_tokens - (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
-        )
-        throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
-        self.num_generated_tokens = 0
-        self.last_stats_tic = time.time()
-        logger.info(
-            f"Decode batch. "
-            f"#running-req: {len(self.running_batch.reqs)}, "
-            f"#token: {num_used}, "
-            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-            f"gen throughput (token/s): {throughput:.2f}, "
-            f"#queue-req: {len(self.waiting_queue)}"
-        )
-
-    def check_memory(self):
-        available_size = (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
-        )
-        if available_size != self.max_total_num_tokens:
-            warnings.warn(
-                "Warning: "
-                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
-                "KV cache pool leak detected!"
-            )
-            exit(1) if crash_on_warning else None
-
-        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
-            warnings.warn(
-                "Warning: "
-                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
-                f"total slots={self.req_to_token_pool.size}\n"
-                "Memory pool leak detected!"
-            )
-            exit(1) if crash_on_warning else None
-
     def handle_generate_request(
         self,
         recv_req: TokenizedGenerateReqInput,
@@ -445,7 +366,88 @@ class Scheduler:
         self.waiting_queue.append(req)
 
-    def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
+    def run_step(self):
+        new_batch = self.get_new_batch_prefill()
+        if new_batch is not None:
+            # Run a new prefill batch
+            result = self.run_batch(new_batch)
+            self.process_batch_result(new_batch, result)
+
+            if not new_batch.is_empty():
+                if self.running_batch is None:
+                    self.running_batch = new_batch
+                else:
+                    self.running_batch.merge_batch(new_batch)
+        else:
+            # Run a decode batch
+            if self.running_batch is not None:
+                # Run a few decode batches continuously for reducing overhead
+                for _ in range(global_config.num_continue_decode_steps):
+                    batch = self.get_new_batch_decode()
+                    if batch:
+                        result = self.run_batch(batch)
+                        self.process_batch_result(batch, result)
+
+                    # Print stats
+                    if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
+                        self.print_decode_stats()
+
+                    if self.running_batch.is_empty():
+                        self.running_batch = None
+                        break
+
+                    if self.out_pyobjs and self.running_batch.has_stream:
+                        break
+            else:
+                self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio
+
+    def print_decode_stats(self):
+        num_used = self.max_total_num_tokens - (
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        )
+        throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
+        self.num_generated_tokens = 0
+        self.last_stats_tic = time.time()
+        logger.info(
+            f"Decode batch. "
+            f"#running-req: {len(self.running_batch.reqs)}, "
+            f"#token: {num_used}, "
+            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+            f"gen throughput (token/s): {throughput:.2f}, "
+            f"#queue-req: {len(self.waiting_queue)}"
+        )
+
+    def check_memory(self):
+        available_size = (
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        )
+        if available_size != self.max_total_num_tokens:
+            warnings.warn(
+                "Warning: "
+                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
+                "KV cache pool leak detected!"
+            )
+            exit(1) if crash_on_warning else None
+
+        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
+            warnings.warn(
+                "Warning: "
+                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
+                f"total slots={self.req_to_token_pool.size}\n"
+                "Memory pool leak detected!"
+            )
+            exit(1) if crash_on_warning else None
+
+    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
+        # Handle the cases where prefill is not allowed
+        if (
+            self.batch_is_full or len(self.waiting_queue) == 0
+        ) and self.current_inflight_req is None:
+            return None
+
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
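print_decode_stats, moved here essentially unchanged, derives throughput from a generated-token counter and a timestamp that it resets after every report. A tiny standalone illustration of that bookkeeping follows; the numbers and sleep are synthetic, not measurements from sglang.

import time

# Tokens generated since the last report divided by elapsed wall time,
# then both counters reset (same pattern as print_decode_stats).
last_stats_tic = time.time()
num_generated_tokens = 0

time.sleep(0.05)
num_generated_tokens += 128  # pretend a few decode steps ran

throughput = num_generated_tokens / (time.time() - last_stats_tic)
num_generated_tokens = 0
last_stats_tic = time.time()
print(f"gen throughput (token/s): {throughput:.2f}")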
@@ -456,8 +458,8 @@ class Scheduler:
         # Get priority queue
         prefix_computed = self.policy.calc_priority(self.waiting_queue)
 
-        # Prefill policy
         num_mixed_running = running_bs if self.is_mixed_chunk else 0
+
         adder = PrefillAdder(
             self.tree_cache,
             self.running_batch,
@@ -517,6 +519,8 @@ class Scheduler:
         if len(can_run_list) == 0:
             return None
 
+        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
+
         # Print stats
         if self.tp_rank == 0:
             if isinstance(self.tree_cache, RadixCache):
@@ -544,7 +548,7 @@ class Scheduler:
                     f"#cached-token: {adder.log_hit_tokens}, "
                     f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                     f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                    f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
                 )
             else:
                 logger.info(
@@ -555,41 +559,97 @@ class Scheduler:
                     f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                     f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                     f"#running-req: {running_bs}, "
-                    f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
                 )
 
-        # Return the new batch
+        # Create a new batch
         new_batch = ScheduleBatch.init_new(
             can_run_list,
             self.req_to_token_pool,
             self.token_to_kv_pool,
             self.tree_cache,
         )
-        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
-        return new_batch
-
-    def forward_prefill_batch(self, batch: ScheduleBatch):
-        # Build batch tensors
-        batch.prepare_for_extend(self.model_config.vocab_size)
+        new_batch.prepare_for_extend(self.model_config.vocab_size)
 
         # Mixed-style chunked prefill
         decoding_reqs = []
         if self.is_mixed_chunk and self.running_batch is not None:
             self.running_batch.prepare_for_decode()
-            batch.mix_with_running(self.running_batch)
+            new_batch.mix_with_running(self.running_batch)
             decoding_reqs = self.running_batch.reqs
             self.running_batch = None
+        new_batch.decoding_reqs = decoding_reqs
+
+        return new_batch
+
+    def get_new_batch_decode(self) -> Optional[ScheduleBatch]:
+        batch = self.running_batch
+
+        # Check if decode out of memory
+        if not batch.check_decode_mem():
+            old_ratio = self.new_token_ratio
+
+            retracted_reqs, new_token_ratio = batch.retract_decode()
+            self.new_token_ratio = new_token_ratio
+
+            logger.info(
+                "Decode out of memory happened. "
+                f"#retracted_reqs: {len(retracted_reqs)}, "
+                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
+            )
+            self.waiting_queue.extend(retracted_reqs)
+        else:
+            self.new_token_ratio = max(
+                self.new_token_ratio - self.new_token_ratio_decay,
+                self.min_new_token_ratio,
+            )
+
+        # Check for jump-forward
+        if not self.disable_regex_jump_forward:
+            jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
+            self.waiting_queue.extend(jump_forward_reqs)
+            if batch.is_empty():
+                return None
+
+        # Update batch tensors
+        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
+        batch.prepare_for_decode()
+        return batch
+
+    def run_batch(self, batch: ScheduleBatch):
         if self.is_generation:
             # Forward and sample the next tokens
-            if batch.extend_num_tokens != 0:
+            if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()
                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                     model_worker_batch
                 )
-            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                next_token_ids
-            )
+            else:
+                logits_output = None
+                if self.tokenizer is not None:
+                    next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+                else:
+                    next_token_ids = [0] * len(batch.reqs)
+            return logits_output, next_token_ids
+        else:  # embedding or reward model
+            assert batch.extend_num_tokens != 0
+            model_worker_batch = batch.get_model_worker_batch()
+            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
+            return embeddings
+
+    def process_batch_result(self, batch: ScheduleBatch, result):
+        if batch.forward_mode.is_decode():
+            self.process_batch_result_decode(batch, result)
+        else:
+            self.process_batch_result_prefill(batch, result)
+
+    def process_batch_result_prefill(self, batch: ScheduleBatch, result):
+        if self.is_generation:
+            logits_output, next_token_ids = result
+            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                next_token_ids
+            )
 
+            if logits_output:
                 # Move logprobs to cpu
                 if logits_output.next_token_logprobs is not None:
                     logits_output.next_token_logprobs = (
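The structural change in this hunk is that running a batch (run_batch) now only produces a result, and a separate process_batch_result step consumes it, with prefill and decode sharing the same two-phase path. Below is a minimal sketch of that split; ToyBatch, the fake "model", and all data are invented stand-ins that only mirror the two-phase shape visible in the diff.

# Toy sketch of the run_batch / process_batch_result split.
from dataclasses import dataclass, field


@dataclass
class ToyBatch:
    is_decode: bool
    reqs: list = field(default_factory=list)


def run_batch(batch):
    # "forward and sample": return one fake token id per request
    return [len(r) % 7 for r in batch.reqs]


def process_batch_result(batch, result):
    kind = "decode" if batch.is_decode else "prefill"
    for req, token_id in zip(batch.reqs, result):
        print(f"{kind}: {req} -> token {token_id}")


for b in [ToyBatch(False, ["prompt-a", "prompt-b"]), ToyBatch(True, ["prompt-a"])]:
    process_batch_result(b, run_batch(b))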
@@ -607,16 +667,7 @@ class Scheduler:
                         logits_output.normalized_prompt_logprobs.tolist()
                     )
 
             next_token_ids = next_token_ids.tolist()
-        else:
-            if self.tokenizer is None:
-                next_token_ids = []
-                for req in batch.reqs:
-                    next_token_ids.append(
-                        next(iter(req.sampling_params.stop_token_ids))
-                    )
-            else:
-                next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
 
             # Check finish conditions
             logprob_pt = 0
@@ -634,7 +685,7 @@ class Scheduler:
 
                 if req.finished():
                     self.tree_cache.cache_finished_req(req)
-                elif req not in decoding_reqs:
+                elif req not in batch.decoding_reqs:
                     # To reduce overhead, only cache prefill reqs
                     self.tree_cache.cache_unfinished_req(req)
 
@@ -646,10 +697,9 @@ class Scheduler:
                     logprob_pt += self.add_logprob_return_values(
                         i, req, logprob_pt, next_token_ids, logits_output
                     )
-        else:
+        else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
-            model_worker_batch = batch.get_model_worker_batch()
-            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
+            embeddings = result
 
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
@@ -671,6 +721,45 @@ class Scheduler:
 
         self.handle_finished_requests(batch)
 
+    def process_batch_result_decode(self, batch: ScheduleBatch, result):
+        logits_output, next_token_ids = result
+        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+            next_token_ids
+        )
+        self.num_generated_tokens += len(batch.reqs)
+
+        # Move logprobs to cpu
+        if logits_output.next_token_logprobs is not None:
+            next_token_logprobs = logits_output.next_token_logprobs[
+                torch.arange(len(next_token_ids), device=next_token_ids.device),
+                next_token_ids,
+            ].tolist()
+
+        next_token_ids = next_token_ids.tolist()
+
+        # Check finish condition
+        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+            req.completion_tokens_wo_jump_forward += 1
+            req.output_ids.append(next_token_id)
+            req.check_finished()
+
+            if req.regex_fsm is not None:
+                req.regex_fsm_state = req.regex_fsm.get_next_state(
+                    req.regex_fsm_state, next_token_id
+                )
+
+            if req.finished():
+                self.tree_cache.cache_finished_req(req)
+
+            if req.return_logprob:
+                req.output_token_logprobs.append(
+                    (next_token_logprobs[i], next_token_id)
+                )
+                if req.top_logprobs_num > 0:
+                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
+
+        self.handle_finished_requests(batch)
+
     def add_logprob_return_values(
         self,
         i: int,
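process_batch_result_decode gathers each request's chosen-token logprob with advanced indexing before moving it to the CPU. The indexing pattern is shown in isolation below with synthetic shapes; the real tensor comes from the model worker, and torch is assumed to be available.

import torch

# Synthetic [batch, vocab] logprobs standing in for logits_output.next_token_logprobs.
next_token_logprobs_full = torch.log_softmax(torch.randn(4, 10), dim=-1)
next_token_ids = torch.tensor([3, 7, 0, 9])

# Same gather as in the diff: one logprob per request, then .tolist() for CPU use.
next_token_logprobs = next_token_logprobs_full[
    torch.arange(len(next_token_ids), device=next_token_ids.device),
    next_token_ids,
].tolist()
print(next_token_logprobs)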
@@ -744,80 +833,6 @@ class Scheduler:
 
         return num_input_logprobs
 
-    def forward_decode_batch(self, batch: ScheduleBatch):
-        # Check if decode out of memory
-        if not batch.check_decode_mem():
-            old_ratio = self.new_token_ratio
-
-            retracted_reqs, new_token_ratio = batch.retract_decode()
-            self.new_token_ratio = new_token_ratio
-
-            logger.info(
-                "Decode out of memory happened. "
-                f"#retracted_reqs: {len(retracted_reqs)}, "
-                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
-            )
-            self.waiting_queue.extend(retracted_reqs)
-        else:
-            self.new_token_ratio = max(
-                self.new_token_ratio - self.new_token_ratio_decay,
-                self.min_new_token_ratio,
-            )
-
-        # Check for jump-forward
-        if not self.disable_regex_jump_forward:
-            jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
-            self.waiting_queue.extend(jump_forward_reqs)
-            if batch.is_empty():
-                return
-
-        # Update batch tensors
-        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
-        batch.prepare_for_decode()
-
-        # Forward and sample the next tokens
-        model_worker_batch = batch.get_model_worker_batch()
-        logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-            model_worker_batch
-        )
-        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-            next_token_ids
-        )
-
-        # Move logprobs to cpu
-        if logits_output.next_token_logprobs is not None:
-            next_token_logprobs = logits_output.next_token_logprobs[
-                torch.arange(len(next_token_ids), device=next_token_ids.device),
-                next_token_ids,
-            ].tolist()
-
-        next_token_ids = next_token_ids.tolist()
-
-        # Check finish condition
-        has_finished = False
-        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
-            req.completion_tokens_wo_jump_forward += 1
-            req.output_ids.append(next_token_id)
-            req.check_finished()
-
-            if req.regex_fsm is not None:
-                req.regex_fsm_state = req.regex_fsm.get_next_state(
-                    req.regex_fsm_state, next_token_id
-                )
-
-            if req.finished():
-                self.tree_cache.cache_finished_req(req)
-                has_finished = True
-
-            if req.return_logprob:
-                req.output_token_logprobs.append(
-                    (next_token_logprobs[i], next_token_id)
-                )
-                if req.top_logprobs_num > 0:
-                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
-
-        self.handle_finished_requests(batch)
-
     def handle_finished_requests(self, batch: ScheduleBatch):
         output_rids = []
         output_meta_info = []
@@ -829,7 +844,7 @@ class Scheduler:
             output_read_offsets = []
             output_skip_special_tokens = []
             output_spaces_between_special_tokens = []
-        else:  # for embedding model
+        else:  # embedding or reward model
            output_embeddings = []
 
         unfinished_indices = []
@@ -886,7 +901,7 @@ class Scheduler:
                         req.normalized_prompt_logprob,
                     )
                     output_meta_info.append(meta_info)
-            else:  # for embedding model
+            else:  # embedding or reward model
                 output_embeddings.append(req.embedding)
                 meta_info = {
                     "prompt_tokens": len(req.origin_input_ids),
@@ -909,7 +924,7 @@ class Scheduler:
                         output_finished_reason,
                     )
                 )
-            else:  # for embedding model
+            else:  # embedding or reward model
                 self.out_pyobjs.append(
                     BatchEmbeddingOut(
                         output_rids,