sglang · Commits · 7ee6c259

Simplify the event loop and expose `--num-continuous-decode-steps` as an argument (#1652)

Commit 7ee6c259 (Unverified), authored Oct 12, 2024 by Lianmin Zheng, committed by GitHub on Oct 12, 2024
Parent: 9610fcd4

Showing 5 changed files with 85 additions and 62 deletions (+85, −62):
python/sglang/global_config.py                  +0   −1
python/sglang/srt/managers/schedule_batch.py    +16  −0
python/sglang/srt/managers/scheduler.py         +59  −60
python/sglang/srt/managers/tp_worker.py         +1   −1
python/sglang/srt/server_args.py                +9   −0
python/sglang/global_config.py

@@ -19,7 +19,6 @@ class GlobalConfig:
         self.new_token_ratio_decay = 0.001
 
         # Runtime constants: others
-        self.num_continue_decode_steps = 10
         self.retract_decode_steps = 20
         self.flashinfer_workspace_size = os.environ.get(
             "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
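The decode-step count is no longer a process-wide constant; it moves to the new `num_continuous_decode_steps` field on `ServerArgs` (shown further down), with a more conservative default of 1. A quick way to confirm the new default, assuming an environment where this version of sglang is installed:

```python
# Minimal sketch, assuming sglang (with this commit) is importable.
# ServerArgs is a dataclass, so the default is readable as a class attribute,
# which is exactly how the new argparse default below references it.
from sglang.srt.server_args import ServerArgs

print(ServerArgs.num_continuous_decode_steps)  # 1 (the old GlobalConfig constant was 10)
```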
python/sglang/srt/managers/schedule_batch.py

@@ -831,6 +831,22 @@ class ScheduleBatch:
             sampling_info=self.sampling_info,
         )
 
+    def copy(self):
+        return ScheduleBatch(
+            reqs=self.reqs,
+            req_to_token_pool=self.req_to_token_pool,
+            token_to_kv_pool=self.token_to_kv_pool,
+            tree_cache=self.tree_cache,
+            forward_mode=self.forward_mode,
+            output_token_ids=self.output_token_ids,
+        )
+
+    def __str__(self):
+        return (
+            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
+            f"#req={(len(self.reqs))})"
+        )
+
 
 @dataclass
 class ModelWorkerBatch:
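The two new helpers are small conveniences: `copy()` builds a new `ScheduleBatch` wrapper that shares the underlying request list and memory pools (a shallow copy), and `__str__` gives a compact one-line summary that is handy in scheduler logs. A standalone toy illustration of the same pattern (this is not the real `ScheduleBatch`, just the shape of the change):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyBatch:
    # Stand-ins for reqs / req_to_token_pool / token_to_kv_pool / tree_cache.
    reqs: List[str]
    forward_mode: str

    def copy(self) -> "ToyBatch":
        # Shallow copy: the reqs list object is shared, mirroring how
        # ScheduleBatch.copy() reuses the pools and the tree cache.
        return ToyBatch(reqs=self.reqs, forward_mode=self.forward_mode)

    def __str__(self) -> str:
        return f"ToyBatch(forward_mode={self.forward_mode}, #req={len(self.reqs)})"


b = ToyBatch(reqs=["req-0", "req-1"], forward_mode="DECODE")
print(b)                        # ToyBatch(forward_mode=DECODE, #req=2)
print(b.copy().reqs is b.reqs)  # True: the copy shares the same request list
```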
python/sglang/srt/managers/scheduler.py

@@ -20,6 +20,7 @@ import logging
 import os
 import time
 import warnings
+from types import SimpleNamespace
 from typing import List, Optional, Union
 
 import torch

@@ -106,7 +107,8 @@ class Scheduler:
             self.send_to_detokenizer = context.socket(zmq.PUSH)
             self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
         else:
-            self.recv_from_tokenizer = self.send_to_detokenizer = None
+            self.recv_from_tokenizer = None
+            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
         # Init tokenizer
         self.model_config = ModelConfig(

@@ -190,7 +192,6 @@ class Scheduler:
         # Init running status
         self.waiting_queue: List[Req] = []
         self.running_batch: ScheduleBatch = None
-        self.out_pyobjs = []
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
         self.num_generated_tokens = 0

@@ -247,13 +248,30 @@
     @torch.inference_mode()
     def event_loop(self):
+        self.last_batch = None
+
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
 
-            self.run_step()
+            batch = self.get_next_batch_to_run()
 
-            self.send_results()
+            if batch:
+                result = self.run_batch(batch)
+                self.process_batch_result(batch, result)
+
+                # Decode multiple steps to reduce the overhead
+                if batch.forward_mode.is_decode():
+                    for _ in range(self.server_args.num_continuous_decode_steps - 1):
+                        if not self.running_batch:
+                            break
+                        self.update_running_batch()
+                        if not self.running_batch:
+                            break
+                        result = self.run_batch(batch)
+                        self.process_batch_result(batch, result)
+
+            self.last_batch = batch
 
     def recv_requests(self):
         if self.tp_rank == 0:

@@ -286,7 +304,9 @@ class Scheduler:
                 self.abort_request(recv_req)
             elif isinstance(recv_req, UpdateWeightReqInput):
                 success, message = self.update_weights(recv_req)
-                self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
+                self.send_to_detokenizer.send_pyobj(
+                    UpdateWeightReqOutput(success, message)
+                )
             elif isinstance(recv_req, ProfileReq):
                 if recv_req == ProfileReq.START_PROFILE:
                     self.start_profile()

@@ -384,12 +404,6 @@ class Scheduler:
         self.waiting_queue.append(req)
 
-    def send_results(self):
-        if self.tp_rank == 0:
-            for obj in self.out_pyobjs:
-                self.send_to_detokenizer.send_pyobj(obj)
-            self.out_pyobjs = []
-
     def print_decode_stats(self):
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
@@ -427,44 +441,32 @@
             )
             exit(1) if crash_on_warning else None
 
-    def run_step(self):
+    def get_next_batch_to_run(self):
+        # Merge prefill to the running batch
+        if (
+            self.last_batch
+            and not self.last_batch.forward_mode.is_decode()
+            and not self.last_batch.is_empty()
+        ):
+            if self.running_batch is None:
+                self.running_batch = self.last_batch
+            else:
+                self.running_batch.merge_batch(self.last_batch)
+
         # Prefill first
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
-            # Run a new prefill batch
-            # replace run_batch with the uncommented line to use pytorch profiler
-            # result = pytorch_profile(
-            #     "profile_prefill_step", self.run_batch, new_batch, data_size=len(new_batch.reqs)
-            # )
-            result = self.run_batch(new_batch)
-            self.process_batch_result(new_batch, result)
-        else:
-            if self.running_batch is not None:
-                # Run a few decode batches continuously for reducing overhead
-                for _ in range(global_config.num_continue_decode_steps):
-                    batch = self.get_new_batch_decode()
-
-                    if batch:
-                        # replace run_batch with the uncommented line to use pytorch profiler
-                        # result = pytorch_profile(
-                        #     "profile_decode_step",
-                        #     self.run_batch,
-                        #     batch,
-                        #     data_size=len(batch.reqs),
-                        # )
-                        result = self.run_batch(batch)
-                        self.process_batch_result(batch, result)
-
-                    if self.running_batch.is_empty():
-                        self.running_batch = None
+            return new_batch
 
-                    if self.running_batch is None:
-                        break
-
-                    if self.out_pyobjs and self.running_batch.has_stream:
-                        break
-            else:
-                self.check_memory()
-                self.new_token_ratio = global_config.init_new_token_ratio
+        # Run decode
+        if self.running_batch is not None:
+            self.update_running_batch()
+            if not self.running_batch:
+                return None
+            return self.running_batch
+        else:
+            self.check_memory()
+            self.new_token_ratio = global_config.init_new_token_ratio
 
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Handle the cases where prefill is not allowed
@@ -607,7 +609,7 @@ class Scheduler:
         return new_batch
 
-    def get_new_batch_decode(self) -> Optional[ScheduleBatch]:
+    def update_running_batch(self):
         batch = self.running_batch
 
         # Check if decode out of memory

@@ -636,11 +638,11 @@ class Scheduler:
             if jump_forward_reqs:
                 self.batch_is_full = False
 
         if batch.is_empty():
-            return None
+            self.running_batch = None
+            return
 
         # Update batch tensors
         batch.prepare_for_decode()
-        return batch
 
     def run_batch(self, batch: ScheduleBatch):
         if self.is_generation:

@@ -657,16 +659,19 @@ class Scheduler:
                 )
             else:
                 next_token_ids = torch.full((batch.batch_size(),), 0)
-            return logits_output, next_token_ids
+            ret = logits_output, next_token_ids
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            return embeddings
+            ret = embeddings
+        return ret
 
     def process_batch_result(self, batch: ScheduleBatch, result):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
+            if batch.is_empty():
+                self.running_batch = None
         else:
             self.process_batch_result_prefill(batch, result)

@@ -728,7 +733,7 @@ class Scheduler:
             )
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
-            embeddings = result
+            embeddings = result.tolist()
 
         # Check finish conditions
         for i, req in enumerate(batch.reqs):

@@ -750,12 +755,6 @@ class Scheduler:
         self.handle_finished_requests(batch)
 
-        if not batch.is_empty():
-            if self.running_batch is None:
-                self.running_batch = batch
-            else:
-                self.running_batch.merge_batch(batch)
-
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
         if batch.sampling_info.penalizer_orchestrator:

@@ -951,7 +950,7 @@ class Scheduler:
         # Send to detokenizer
         if output_rids:
             if self.is_generation:
-                self.out_pyobjs.append(
+                self.send_to_detokenizer.send_pyobj(
                     BatchTokenIDOut(
                         output_rids,
                         output_vids,

@@ -965,7 +964,7 @@ class Scheduler:
                     )
                 )
             else:  # embedding or reward model
-                self.out_pyobjs.append(
+                self.send_to_detokenizer.send_pyobj(
                     BatchEmbeddingOut(
                         output_rids,
                         output_embeddings,
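Taken together, the scheduler changes replace the old `run_step`/`send_results` pair with smaller responsibilities: `get_next_batch_to_run` picks the next batch (merging the previous prefill batch into the running batch first), `run_batch` executes it, and the result-processing path sends outputs straight to the detokenizer socket, so the `out_pyobjs` buffer disappears. A condensed, illustrative sketch of the resulting loop (paraphrased from the hunks above, not the literal source):

```python
# Condensed paraphrase of Scheduler.event_loop after this commit.
# `self` stands for the real Scheduler; only control flow is shown.
def event_loop(self):
    self.last_batch = None
    while True:
        self.process_input_requests(self.recv_requests())

        batch = self.get_next_batch_to_run()  # prefill first, otherwise decode
        if batch:
            result = self.run_batch(batch)
            self.process_batch_result(batch, result)  # outputs go to the detokenizer

            # Amortize scheduling overhead: optionally run extra decode steps
            # back-to-back (--num-continuous-decode-steps, default 1).
            if batch.forward_mode.is_decode():
                for _ in range(self.server_args.num_continuous_decode_steps - 1):
                    if not self.running_batch:
                        break
                    self.update_running_batch()
                    if not self.running_batch:
                        break
                    result = self.run_batch(batch)
                    self.process_batch_result(batch, result)

        self.last_batch = batch
```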
python/sglang/srt/managers/tp_worker.py

@@ -118,7 +118,7 @@ class TpModelWorker:
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
-        embeddings = logits_output.embeddings.tolist()
+        embeddings = logits_output.embeddings
         return embeddings
 
     def update_weights(self, recv_req: UpdateWeightReqInput):
python/sglang/srt/server_args.py

@@ -111,6 +111,7 @@ class ServerArgs:
     torchao_config: str = ""
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
+    num_continuous_decode_steps: int = 1
 
     def __post_init__(self):
         # Set missing default values

@@ -559,6 +560,14 @@ class ServerArgs:
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
             "This only affects Triton attention kernels.",
         )
+        parser.add_argument(
+            "--num-continuous-decode-steps",
+            type=int,
+            default=ServerArgs.num_continuous_decode_steps,
+            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
+            "This can potentially increase throughput but may also increase time-to-first-token latency. "
+            "The default value is 1, meaning only run one decoding step at a time.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
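As the new help text says, raising this value trades some time-to-first-token for throughput, because the scheduler does not look at newly arrived requests between the extra decode steps. A small, hedged sketch of how the flag flows into `ServerArgs` (the model path is a placeholder; launching a real server would be something like `python -m sglang.launch_server --model-path <model> --num-continuous-decode-steps 2`):

```python
# Minimal sketch, assuming sglang (with this commit) is installed.
# "dummy/model" is a placeholder; parsing arguments does not load any model.
import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)

args = parser.parse_args(
    ["--model-path", "dummy/model", "--num-continuous-decode-steps", "2"]
)
server_args = ServerArgs.from_cli_args(args)
print(server_args.num_continuous_decode_steps)  # 2
```

The scheduler reads this value directly in its event loop (`self.server_args.num_continuous_decode_steps`), so no other plumbing is needed.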