sglang / Commits / da1ffed6

Commit da1ffed6 (unverified), authored Oct 13, 2024 by Lianmin Zheng, committed by GitHub on Oct 13, 2024

Add output_ids into ScheduleBatch (#1659)

Parent: 48761171

Showing 3 changed files with 29 additions and 24 deletions (+29, -24):
    python/sglang/bench_latency.py                  +7   -4
    python/sglang/srt/managers/schedule_batch.py    +11  -11
    python/sglang/srt/managers/scheduler.py         +11  -9
python/sglang/bench_latency.py
@@ -232,17 +232,18 @@ def extend(reqs, model_runner):
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
-    next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
+    next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits, batch


 @torch.inference_mode()
 def decode(input_token_ids, batch, model_runner):
-    batch.prepare_for_decode(input_token_ids)
+    batch.output_ids = input_token_ids
+    batch.prepare_for_decode()
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
-    next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
+    next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits
@@ -252,6 +253,7 @@ def correctness_test(
     bench_args,
     tp_rank,
 ):
     configure_logger(server_args, prefix=f" TP{tp_rank}")
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

     # Load the model
@@ -279,8 +281,9 @@ def correctness_test(
     output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
     for _ in range(bench_args.output_len[0] - 1):
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        next_token_ids_list = next_token_ids.tolist()
         for i in range(len(reqs)):
-            output_ids[i].append(next_token_ids[i])
+            output_ids[i].append(next_token_ids_list[i])

     # Print
     for i in range(len(reqs)):
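Taken together, the bench_latency.py changes switch the benchmark to the new calling convention: the caller stores the previous step's sampled ids on batch.output_ids and calls prepare_for_decode() with no arguments, and next_token_ids stays a torch tensor between steps, converted with .tolist() only when appending to Python lists. A minimal sketch of the resulting loop, reusing extend and decode from this file (reqs, input_ids, model_runner, and bench_args are assumed to be set up as elsewhere in bench_latency.py):

    # Prefill once; next_token_ids comes back as a torch tensor, not a list.
    next_token_ids, _, batch = extend(reqs, model_runner)
    first_ids = next_token_ids.tolist()
    output_ids = [input_ids[i] + [first_ids[i]] for i in range(len(input_ids))]

    # Decode step by step; decode() sets batch.output_ids and then calls
    # batch.prepare_for_decode() with no arguments, per the change above.
    for _ in range(bench_args.output_len[0] - 1):
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        next_token_ids_list = next_token_ids.tolist()  # convert once per step
        for i in range(len(reqs)):
            output_ids[i].append(next_token_ids_list[i])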
python/sglang/srt/managers/schedule_batch.py
@@ -410,6 +410,8 @@ class ScheduleBatch:
     seq_lens: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
+    output_ids: torch.Tensor = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
@@ -720,19 +722,12 @@ class ScheduleBatch:
         return jump_forward_reqs

-    def prepare_for_decode(self, input_ids=None):
+    def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE

-        if input_ids is None:
-            input_ids = [
-                r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
-                for r in self.reqs
-            ]
-
-        self.input_ids = torch.tensor(
-            input_ids, dtype=torch.int32, device=self.seq_lens.device
-        )
+        self.input_ids = self.output_ids
         self.seq_lens.add_(1)
+        self.output_ids = None

         # Alloc mem
         bs = len(self.reqs)
@@ -759,6 +754,7 @@ class ScheduleBatch:
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.seq_lens = self.seq_lens[new_indices]
         self.out_cache_loc = None
+        self.output_ids = self.output_ids[new_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
         if self.return_logprob:
             self.top_logprobs_nums = [
@@ -783,6 +779,8 @@ class ScheduleBatch:
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
         self.out_cache_loc = None
+        if self.output_ids is not None:
+            self.output_ids = torch.concat([self.output_ids, other.output_ids])
         if self.return_logprob and other.return_logprob:
             self.top_logprobs_nums.extend(other.top_logprobs_nums)
         elif self.return_logprob:
@@ -838,7 +836,9 @@ class ScheduleBatch:
             token_to_kv_pool=self.token_to_kv_pool,
             tree_cache=self.tree_cache,
             forward_mode=self.forward_mode,
-            output_token_ids=self.output_token_ids,
+            output_ids=self.output_ids,
             sampling_info=self.sampling_info,
             decoding_reqs=self.decoding_reqs,
         )

     def __str__(self):
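The net effect in schedule_batch.py is that output_ids becomes an ordinary per-request tensor on ScheduleBatch: whoever ran the last sampling step stores the sampled ids there, prepare_for_decode() promotes them to input_ids instead of rebuilding a tensor from each request's Python lists, and filter_batch/merge_batch keep the field consistent. A minimal, self-contained sketch of that contract (a toy class, not the real ScheduleBatch; only the fields touched by this commit are modeled):

    from dataclasses import dataclass

    import torch


    @dataclass
    class TinyBatch:
        seq_lens: torch.Tensor
        input_ids: torch.Tensor = None
        output_ids: torch.Tensor = None  # last sampled token per request

        def prepare_for_decode(self):
            # Mirrors the new behavior: consume output_ids as the next input_ids.
            self.input_ids = self.output_ids
            self.seq_lens.add_(1)
            self.output_ids = None

        def filter_batch(self, keep_indices):
            # output_ids is sliced alongside the other per-request tensors.
            self.seq_lens = self.seq_lens[keep_indices]
            if self.output_ids is not None:
                self.output_ids = self.output_ids[keep_indices]


    batch = TinyBatch(seq_lens=torch.tensor([5, 7]))
    batch.output_ids = torch.tensor([11, 22], dtype=torch.int32)  # sampled in the previous step
    batch.prepare_for_decode()
    assert batch.input_ids.tolist() == [11, 22]
    assert batch.seq_lens.tolist() == [6, 8]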
python/sglang/srt/managers/scheduler.py
@@ -247,7 +247,7 @@ class Scheduler:
         )

     @torch.inference_mode()
-    def event_loop(self):
+    def event_loop_normal(self):
         self.last_batch = None

         while True:
@@ -411,9 +411,10 @@ class Scheduler:
         throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()
+        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
         logger.info(
             f"Decode batch. "
-            f"#running-req: {len(self.running_batch.reqs)}, "
+            f"#running-req: {num_running_reqs}, "
             f"#token: {num_used}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"gen throughput (token/s): {throughput:.2f}, "
@@ -659,6 +660,7 @@ class Scheduler:
                 )
             else:
                 next_token_ids = torch.full((batch.batch_size(),), 0)
+            batch.output_ids = next_token_ids
             ret = logits_output, next_token_ids
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
@@ -753,7 +755,7 @@ class Scheduler:
                     # Inflight request would get a new req idx
                     self.req_to_token_pool.free(req.req_pool_idx)

-        self.handle_finished_requests(batch)
+        self.stream_output(batch)

     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
@@ -793,7 +795,7 @@ class Scheduler:
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

-        self.handle_finished_requests(batch)
+        self.stream_output(batch)

         self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
@@ -872,7 +874,7 @@ class Scheduler:
         return num_input_logprobs

-    def handle_finished_requests(self, batch: ScheduleBatch):
+    def stream_output(self, batch: ScheduleBatch):
         output_rids = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
@@ -949,6 +951,9 @@ class Scheduler:
                 }
                 output_meta_info.append(meta_info)

+        # Remove finished reqs: update batch tensors
+        batch.filter_batch(unfinished_indices)
+
         # Send to detokenizer
         if output_rids:
             if self.is_generation:
@@ -976,9 +981,6 @@ class Scheduler:
                     )
                 )

-        # Remove finished reqs: update batch tensors
-        batch.filter_batch(unfinished_indices)
-
     def flush_cache(self):
         if len(self.waiting_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
@@ -1060,7 +1062,7 @@ def run_scheduler_process(
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
         pipe_writer.send("ready")
-        scheduler.event_loop()
+        scheduler.event_loop_normal()
     except Exception:
         msg = get_exception_traceback()
         logger.error(msg)
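On the scheduler side, the prefill result handler now remembers the sampled ids on batch.output_ids so the next prepare_for_decode() can consume them, handle_finished_requests is renamed to stream_output, and batch.filter_batch(unfinished_indices) moves so that finished requests are dropped from the batch tensors before the outputs are handed to the detokenizer. A toy, self-contained illustration of that ordering (all names are stand-ins, not the real Scheduler API):

    import torch


    class ToyBatch:
        """Carries only the per-request state relevant to the ordering change."""

        def __init__(self, rids, output_ids):
            self.rids = rids
            self.output_ids = output_ids

        def filter_batch(self, keep_indices):
            self.rids = [self.rids[i] for i in keep_indices]
            self.output_ids = self.output_ids[torch.tensor(keep_indices, dtype=torch.long)]


    def stream_output(batch, finished_rids):
        # Collect the outputs to stream for every request, finished or not ...
        streamed = [(rid, batch.output_ids[i].item()) for i, rid in enumerate(batch.rids)]
        # ... then shrink the batch tensors so the next decode step (which reads
        # batch.output_ids in prepare_for_decode) only sees unfinished requests.
        keep = [i for i, rid in enumerate(batch.rids) if rid not in finished_rids]
        batch.filter_batch(keep)
        return streamed


    batch = ToyBatch(rids=["a", "b", "c"], output_ids=torch.tensor([7, 8, 9]))
    print(stream_output(batch, finished_rids={"b"}))  # [('a', 7), ('b', 8), ('c', 9)]
    print(batch.rids, batch.output_ids.tolist())      # ['a', 'c'] [7, 9]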