sglang / Commits / 7ee6c259

Simplify the event loop and expose `--num-continuous-decode-steps` as an argument (#1652)

Commit 7ee6c259 (unverified), authored Oct 12, 2024 by Lianmin Zheng, committed by GitHub on Oct 12, 2024.
Parent: 9610fcd4

Showing 5 changed files with 85 additions and 62 deletions (+85 −62):
python/sglang/global_config.py                 +0   −1
python/sglang/srt/managers/schedule_batch.py   +16  −0
python/sglang/srt/managers/scheduler.py        +59  −60
python/sglang/srt/managers/tp_worker.py        +1   −1
python/sglang/srt/server_args.py               +9   −0
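Overall, the commit drops the hard-coded GlobalConfig.num_continue_decode_steps constant and exposes the same behavior through a new ServerArgs.num_continuous_decode_steps field, which is wired to the `--num-continuous-decode-steps` flag at server launch. A minimal sketch of setting the new knob programmatically, assuming an sglang install that contains this change; the model path is only a placeholder:

from sglang.srt.server_args import ServerArgs

# Placeholder model path; any model you normally serve works here.
server_args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    num_continuous_decode_steps=4,  # run up to 4 decode steps per scheduling round
)
print(server_args.num_continuous_decode_steps)  # 4

The equivalent CLI usage is simply passing `--num-continuous-decode-steps 4` when launching the server.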
python/sglang/global_config.py

@@ -19,7 +19,6 @@ class GlobalConfig:
         self.new_token_ratio_decay = 0.001
 
         # Runtime constants: others
-        self.num_continue_decode_steps = 10
         self.retract_decode_steps = 20
         self.flashinfer_workspace_size = os.environ.get(
             "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
python/sglang/srt/managers/schedule_batch.py

@@ -831,6 +831,22 @@ class ScheduleBatch:
             sampling_info=self.sampling_info,
         )
 
+    def copy(self):
+        return ScheduleBatch(
+            reqs=self.reqs,
+            req_to_token_pool=self.req_to_token_pool,
+            token_to_kv_pool=self.token_to_kv_pool,
+            tree_cache=self.tree_cache,
+            forward_mode=self.forward_mode,
+            output_token_ids=self.output_token_ids,
+        )
+
+    def __str__(self):
+        return (
+            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
+            f"#req={(len(self.reqs))})"
+        )
+
 
 @dataclass
 class ModelWorkerBatch:
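The new copy() builds a fresh ScheduleBatch that reuses the same request list and memory pools rather than deep-copying them, and __str__ gives a compact one-line summary for logging. A self-contained sketch of the same pattern on a simplified stand-in class (the field names here are illustrative, not the real ScheduleBatch):

from dataclasses import dataclass
from typing import List

@dataclass
class MiniBatch:
    reqs: List[str]      # shared by reference, like ScheduleBatch.reqs
    forward_mode: str

    def copy(self) -> "MiniBatch":
        # Shallow copy: the new object points at the same reqs list.
        return MiniBatch(reqs=self.reqs, forward_mode=self.forward_mode)

    def __str__(self) -> str:
        return f"MiniBatch(forward_mode={self.forward_mode}, #req={len(self.reqs)})"

b = MiniBatch(reqs=["r0", "r1"], forward_mode="DECODE")
print(str(b.copy()))  # MiniBatch(forward_mode=DECODE, #req=2)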
python/sglang/srt/managers/scheduler.py

@@ -20,6 +20,7 @@ import logging
 import os
 import time
 import warnings
+from types import SimpleNamespace
 from typing import List, Optional, Union
 
 import torch
@@ -106,7 +107,8 @@ class Scheduler:
             self.send_to_detokenizer = context.socket(zmq.PUSH)
             self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
         else:
-            self.recv_from_tokenizer = self.send_to_detokenizer = None
+            self.recv_from_tokenizer = None
+            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
         # Init tokenizer
         self.model_config = ModelConfig(
@@ -190,7 +192,6 @@ class Scheduler:
         # Init running status
         self.waiting_queue: List[Req] = []
         self.running_batch: ScheduleBatch = None
-        self.out_pyobjs = []
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
         self.num_generated_tokens = 0
@@ -247,13 +248,30 @@ class Scheduler:
     @torch.inference_mode()
     def event_loop(self):
+        self.last_batch = None
+
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
 
-            self.run_step()
-
-            self.send_results()
+            batch = self.get_next_batch_to_run()
+
+            if batch:
+                result = self.run_batch(batch)
+                self.process_batch_result(batch, result)
+
+                # Decode multiple steps to reduce the overhead
+                if batch.forward_mode.is_decode():
+                    for _ in range(self.server_args.num_continuous_decode_steps - 1):
+                        if not self.running_batch:
+                            break
+                        self.update_running_batch()
+                        if not self.running_batch:
+                            break
+                        result = self.run_batch(batch)
+                        self.process_batch_result(batch, result)
+
+            self.last_batch = batch
 
     def recv_requests(self):
         if self.tp_rank == 0:
@@ -286,7 +304,9 @@ class Scheduler:
             self.abort_request(recv_req)
         elif isinstance(recv_req, UpdateWeightReqInput):
             success, message = self.update_weights(recv_req)
-            self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
+            self.send_to_detokenizer.send_pyobj(
+                UpdateWeightReqOutput(success, message)
+            )
         elif isinstance(recv_req, ProfileReq):
             if recv_req == ProfileReq.START_PROFILE:
                 self.start_profile()
@@ -384,12 +404,6 @@ class Scheduler:
             self.waiting_queue.append(req)
 
-    def send_results(self):
-        if self.tp_rank == 0:
-            for obj in self.out_pyobjs:
-                self.send_to_detokenizer.send_pyobj(obj)
-            self.out_pyobjs = []
-
     def print_decode_stats(self):
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
@@ -427,44 +441,32 @@ class Scheduler:
         )
         exit(1) if crash_on_warning else None
 
-    def run_step(self):
+    def get_next_batch_to_run(self):
+        # Merge prefill to the running batch
+        if (
+            self.last_batch
+            and not self.last_batch.forward_mode.is_decode()
+            and not self.last_batch.is_empty()
+        ):
+            if self.running_batch is None:
+                self.running_batch = self.last_batch
+            else:
+                self.running_batch.merge_batch(self.last_batch)
+
+        # Prefill first
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
-            # Run a new prefill batch
-            # replace run_batch with the uncommented line to use pytorch profiler
-            # result = pytorch_profile(
-            #     "profile_prefill_step", self.run_batch, new_batch, data_size=len(new_batch.reqs)
-            # )
-            result = self.run_batch(new_batch)
-            self.process_batch_result(new_batch, result)
-        else:
-            if self.running_batch is not None:
-                # Run a few decode batches continuously for reducing overhead
-                for _ in range(global_config.num_continue_decode_steps):
-                    batch = self.get_new_batch_decode()
-
-                    if batch:
-                        # replace run_batch with the uncommented line to use pytorch profiler
-                        # result = pytorch_profile(
-                        #     "profile_decode_step",
-                        #     self.run_batch,
-                        #     batch,
-                        #     data_size=len(batch.reqs),
-                        # )
-                        result = self.run_batch(batch)
-                        self.process_batch_result(batch, result)
-
-                    if self.running_batch.is_empty():
-                        self.running_batch = None
-
-                    if self.running_batch is None:
-                        break
-
-                    if self.out_pyobjs and self.running_batch.has_stream:
-                        break
-            else:
-                self.check_memory()
-                self.new_token_ratio = global_config.init_new_token_ratio
+            return new_batch
+
+        # Run decode
+        if self.running_batch is not None:
+            self.update_running_batch()
+            if not self.running_batch:
+                return None
+            return self.running_batch
+        else:
+            self.check_memory()
+            self.new_token_ratio = global_config.init_new_token_ratio
 
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Handle the cases where prefill is not allowed
@@ -607,7 +609,7 @@ class Scheduler:
         return new_batch
 
-    def get_new_batch_decode(self) -> Optional[ScheduleBatch]:
+    def update_running_batch(self):
         batch = self.running_batch
 
         # Check if decode out of memory
@@ -636,11 +638,11 @@ class Scheduler:
         if jump_forward_reqs:
             self.batch_is_full = False
 
         if batch.is_empty():
-            return None
+            self.running_batch = None
+            return
 
         # Update batch tensors
         batch.prepare_for_decode()
-        return batch
 
     def run_batch(self, batch: ScheduleBatch):
         if self.is_generation:
@@ -657,16 +659,19 @@ class Scheduler:
                 )
             else:
                 next_token_ids = torch.full((batch.batch_size(),), 0)
-            return logits_output, next_token_ids
+            ret = logits_output, next_token_ids
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            return embeddings
+            ret = embeddings
+        return ret
 
     def process_batch_result(self, batch: ScheduleBatch, result):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
+            if batch.is_empty():
+                self.running_batch = None
         else:
             self.process_batch_result_prefill(batch, result)
@@ -728,7 +733,7 @@ class Scheduler:
                 )
             else:  # embedding or reward model
                 assert batch.extend_num_tokens != 0
-                embeddings = result
+                embeddings = result.tolist()
 
         # Check finish conditions
         for i, req in enumerate(batch.reqs):
@@ -750,12 +755,6 @@ class Scheduler:
         self.handle_finished_requests(batch)
 
-        if not batch.is_empty():
-            if self.running_batch is None:
-                self.running_batch = batch
-            else:
-                self.running_batch.merge_batch(batch)
-
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
         if batch.sampling_info.penalizer_orchestrator:
@@ -951,7 +950,7 @@ class Scheduler:
             # Send to detokenizer
             if output_rids:
                 if self.is_generation:
-                    self.out_pyobjs.append(
+                    self.send_to_detokenizer.send_pyobj(
                         BatchTokenIDOut(
                             output_rids,
                             output_vids,
@@ -965,7 +964,7 @@ class Scheduler:
                    )
                )
            else:  # embedding or reward model
-                self.out_pyobjs.append(
+                self.send_to_detokenizer.send_pyobj(
                    BatchEmbeddingOut(
                        output_rids,
                        output_embeddings,
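The refactor replaces run_step() and the out_pyobjs/send_results buffering with a flatter loop: event_loop asks get_next_batch_to_run() for work, runs it, processes the result, and for decode batches repeats the decode step up to num_continuous_decode_steps times before polling for new requests again. A stripped-down, self-contained sketch of that control flow with toy classes; this is an illustration of the loop structure, not the real sglang Scheduler:

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class ToyScheduler:
    """Toy model of the refactored control flow; not the real sglang Scheduler."""
    num_continuous_decode_steps: int = 1
    waiting: List[str] = field(default_factory=list)   # queued prefill work
    running: List[str] = field(default_factory=list)   # requests still decoding
    steps_left: int = 3                                 # fake finish condition

    def get_next_batch_to_run(self) -> Optional[str]:
        # Prefill first, otherwise keep decoding the running batch.
        if self.waiting:
            self.running.extend(self.waiting)
            self.waiting.clear()
            return "prefill"
        return "decode" if self.running else None

    def run_batch(self, batch: str) -> str:
        return f"ran {batch} for {len(self.running)} request(s)"

    def process_batch_result(self, batch: str, result: str) -> None:
        print(result)
        if batch == "decode":
            self.steps_left -= 1
            if self.steps_left == 0:
                self.running.clear()   # everything finished

    def event_loop_once(self) -> None:
        batch = self.get_next_batch_to_run()
        if batch is None:
            return
        result = self.run_batch(batch)
        self.process_batch_result(batch, result)
        if batch == "decode":
            # Run a few extra decode steps before polling for new requests,
            # mirroring --num-continuous-decode-steps.
            for _ in range(self.num_continuous_decode_steps - 1):
                if not self.running:
                    break
                self.process_batch_result(batch, self.run_batch(batch))

s = ToyScheduler(num_continuous_decode_steps=2, waiting=["req-0", "req-1"])
for _ in range(4):
    s.event_loop_once()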
python/sglang/srt/managers/tp_worker.py

@@ -118,7 +118,7 @@ class TpModelWorker:
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
-        embeddings = logits_output.embeddings.tolist()
+        embeddings = logits_output.embeddings
         return embeddings
 
     def update_weights(self, recv_req: UpdateWeightReqInput):
python/sglang/srt/server_args.py

@@ -111,6 +111,7 @@ class ServerArgs:
     torchao_config: str = ""
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
+    num_continuous_decode_steps: int = 1
 
     def __post_init__(self):
         # Set missing default values
@@ -559,6 +560,14 @@ class ServerArgs:
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
             "This only affects Triton attention kernels.",
         )
+        parser.add_argument(
+            "--num-continuous-decode-steps",
+            type=int,
+            default=ServerArgs.num_continuous_decode_steps,
+            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
+            "This can potentially increase throughput but may also increase time-to-first-token latency. "
+            "The default value is 1, meaning only run one decoding step at a time.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
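The pattern here, adding a dataclass field and an argparse option whose default points back at that field, is what keeps the CLI flag and ServerArgs in sync; from_cli_args (shown as context above) then rebuilds the dataclass from the parsed namespace. A generic, self-contained sketch of that round trip using a toy Args class, not sglang's actual implementation:

import argparse
import dataclasses

@dataclasses.dataclass
class Args:
    num_continuous_decode_steps: int = 1  # mirrors the new ServerArgs field

parser = argparse.ArgumentParser()
parser.add_argument(
    "--num-continuous-decode-steps",
    type=int,
    default=Args.num_continuous_decode_steps,  # default stays defined in one place
)
ns = parser.parse_args(["--num-continuous-decode-steps", "4"])

# argparse stores "--num-continuous-decode-steps" as "num_continuous_decode_steps",
# so dataclass fields and CLI flags line up by name.
args = Args(**{f.name: getattr(ns, f.name) for f in dataclasses.fields(Args)})
print(args)  # Args(num_continuous_decode_steps=4)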