Unverified commit 7ee6c259, authored Oct 12, 2024 by Lianmin Zheng, committed via GitHub on Oct 12, 2024
Simplify the event loop and expose `--num-continuous-decode-steps` as an argument (#1652)
Parent: 9610fcd4

Showing 5 changed files with 85 additions and 62 deletions (+85, -62):
python/sglang/global_config.py (+0, -1)
python/sglang/srt/managers/schedule_batch.py (+16, -0)
python/sglang/srt/managers/scheduler.py (+59, -60)
python/sglang/srt/managers/tp_worker.py (+1, -1)
python/sglang/srt/server_args.py (+9, -0)
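Taken together, the change restructures the scheduler around an explicit batch-selection step: the old run_step/send_results pair is replaced by get_next_batch_to_run and update_running_batch plus direct sends to the detokenizer, and the number of back-to-back decode steps moves from the hard-coded GlobalConfig.num_continue_decode_steps constant (previously 10) to a new --num-continuous-decode-steps server argument (default 1).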

python/sglang/global_config.py

@@ -19,7 +19,6 @@ class GlobalConfig:
         self.new_token_ratio_decay = 0.001

         # Runtime constants: others
-        self.num_continue_decode_steps = 10
         self.retract_decode_steps = 20
         self.flashinfer_workspace_size = os.environ.get(
             "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
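With this removal, the hard-coded GlobalConfig.num_continue_decode_steps constant no longer controls multi-step decoding; the scheduler instead reads the new num_continuous_decode_steps field from ServerArgs (see the scheduler.py and server_args.py changes below).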

python/sglang/srt/managers/schedule_batch.py

@@ -831,6 +831,22 @@ class ScheduleBatch:
             sampling_info=self.sampling_info,
         )

+    def copy(self):
+        return ScheduleBatch(
+            reqs=self.reqs,
+            req_to_token_pool=self.req_to_token_pool,
+            token_to_kv_pool=self.token_to_kv_pool,
+            tree_cache=self.tree_cache,
+            forward_mode=self.forward_mode,
+            output_token_ids=self.output_token_ids,
+        )
+
+    def __str__(self):
+        return (
+            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
+            f"#req={(len(self.reqs))})"
+        )
+

 @dataclass
 class ModelWorkerBatch:
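The two new helpers are small conveniences: copy() produces a shallow copy that reuses the same request list, memory pools, and tree cache, and __str__ renders a compact one-line summary (forward mode plus request count) that is handy when logging batches.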

python/sglang/srt/managers/scheduler.py

@@ -20,6 +20,7 @@ import logging
 import os
 import time
 import warnings
+from types import SimpleNamespace
 from typing import List, Optional, Union

 import torch
@@ -106,7 +107,8 @@ class Scheduler:
             self.send_to_detokenizer = context.socket(zmq.PUSH)
             self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
         else:
-            self.recv_from_tokenizer = self.send_to_detokenizer = None
+            self.recv_from_tokenizer = None
+            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

         # Init tokenizer
         self.model_config = ModelConfig(
@@ -190,7 +192,6 @@ class Scheduler:
         # Init running status
         self.waiting_queue: List[Req] = []
         self.running_batch: ScheduleBatch = None
-        self.out_pyobjs = []
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
         self.num_generated_tokens = 0
@@ -247,13 +248,30 @@ class Scheduler:

     @torch.inference_mode()
     def event_loop(self):
+        self.last_batch = None
+
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)

-            self.run_step()
-
-            self.send_results()
+            batch = self.get_next_batch_to_run()
+
+            if batch:
+                result = self.run_batch(batch)
+                self.process_batch_result(batch, result)
+
+                # Decode multiple steps to reduce the overhead
+                if batch.forward_mode.is_decode():
+                    for _ in range(self.server_args.num_continuous_decode_steps - 1):
+                        if not self.running_batch:
+                            break
+                        self.update_running_batch()
+                        if not self.running_batch:
+                            break
+                        result = self.run_batch(batch)
+                        self.process_batch_result(batch, result)
+
+            self.last_batch = batch

     def recv_requests(self):
         if self.tp_rank == 0:
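The core behavioral change is the inner loop above: after running a decode batch, the scheduler may run up to num_continuous_decode_steps - 1 additional decode steps before it polls for new requests again. A minimal, self-contained sketch of that control flow (illustration only, not sglang code; all names below are placeholders):

    # Toy illustration of the new event-loop shape: each outer iteration polls for
    # work once, and a decode batch may then be stepped several times back-to-back.
    def toy_event_loop(pending_tokens: int, num_continuous_decode_steps: int = 3) -> None:
        while pending_tokens > 0:
            # Stand-in for recv_requests / process_input_requests / get_next_batch_to_run.
            print("schedule: pick next batch")

            # Stand-in for run_batch + process_batch_result on a decode batch.
            pending_tokens -= 1
            print("  decode step")

            # Extra decode steps without returning to the scheduling/polling code.
            for _ in range(num_continuous_decode_steps - 1):
                if pending_tokens == 0:
                    break
                pending_tokens -= 1
                print("  extra decode step (no scheduling overhead in between)")

    toy_event_loop(pending_tokens=7)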
@@ -286,7 +304,9 @@ class Scheduler:
             self.abort_request(recv_req)
         elif isinstance(recv_req, UpdateWeightReqInput):
             success, message = self.update_weights(recv_req)
-            self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
+            self.send_to_detokenizer.send_pyobj(
+                UpdateWeightReqOutput(success, message)
+            )
         elif isinstance(recv_req, ProfileReq):
             if recv_req == ProfileReq.START_PROFILE:
                 self.start_profile()
@@ -384,12 +404,6 @@ class Scheduler:
             self.waiting_queue.append(req)

-    def send_results(self):
-        if self.tp_rank == 0:
-            for obj in self.out_pyobjs:
-                self.send_to_detokenizer.send_pyobj(obj)
-            self.out_pyobjs = []
-
     def print_decode_stats(self):
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
@@ -427,41 +441,29 @@ class Scheduler:
             )
             exit(1) if crash_on_warning else None

-    def run_step(self):
-        new_batch = self.get_new_batch_prefill()
-        if new_batch is not None:
-            # Run a new prefill batch
-            # replace run_batch with the uncommented line to use pytorch profiler
-            # result = pytorch_profile(
-            #     "profile_prefill_step", self.run_batch, new_batch, data_size=len(new_batch.reqs)
-            # )
-            result = self.run_batch(new_batch)
-            self.process_batch_result(new_batch, result)
-        else:
-            if self.running_batch is not None:
-                # Run a few decode batches continuously for reducing overhead
-                for _ in range(global_config.num_continue_decode_steps):
-                    batch = self.get_new_batch_decode()
-
-                    if batch:
-                        # replace run_batch with the uncommented line to use pytorch profiler
-                        # result = pytorch_profile(
-                        #     "profile_decode_step",
-                        #     self.run_batch,
-                        #     batch,
-                        #     data_size=len(batch.reqs),
-                        # )
-                        result = self.run_batch(batch)
-                        self.process_batch_result(batch, result)
-
-                        if self.running_batch.is_empty():
-                            self.running_batch = None
-
-                    if self.running_batch is None:
-                        break
-
-                    if self.out_pyobjs and self.running_batch.has_stream:
-                        break
-            else:
-                self.check_memory()
-                self.new_token_ratio = global_config.init_new_token_ratio
+    def get_next_batch_to_run(self):
+        # Merge prefill to the running batch
+        if (
+            self.last_batch
+            and not self.last_batch.forward_mode.is_decode()
+            and not self.last_batch.is_empty()
+        ):
+            if self.running_batch is None:
+                self.running_batch = self.last_batch
+            else:
+                self.running_batch.merge_batch(self.last_batch)
+
+        # Prefill first
+        new_batch = self.get_new_batch_prefill()
+        if new_batch is not None:
+            return new_batch
+
+        # Run decode
+        if self.running_batch is not None:
+            self.update_running_batch()
+            if not self.running_batch:
+                return None
+            return self.running_batch
+        else:
+            self.check_memory()
+            self.new_token_ratio = global_config.init_new_token_ratio
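In the new flow, event_loop remembers the batch from the previous iteration as last_batch; get_next_batch_to_run first merges a finished prefill batch into running_batch, then prefers a fresh prefill batch if one can be formed, and otherwise falls back to stepping the running decode batch via update_running_batch (the renamed get_new_batch_decode, shown next).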
@@ -607,7 +609,7 @@ class Scheduler:
         return new_batch

-    def get_new_batch_decode(self) -> Optional[ScheduleBatch]:
+    def update_running_batch(self):
         batch = self.running_batch

         # Check if decode out of memory
@@ -636,11 +638,11 @@ class Scheduler:
         if jump_forward_reqs:
             self.batch_is_full = False

         if batch.is_empty():
-            return None
+            self.running_batch = None
+            return

         # Update batch tensors
         batch.prepare_for_decode()
-        return batch

     def run_batch(self, batch: ScheduleBatch):
         if self.is_generation:
@@ -657,16 +659,19 @@ class Scheduler:
                 )
             else:
                 next_token_ids = torch.full((batch.batch_size(),), 0)
-            return logits_output, next_token_ids
+            ret = logits_output, next_token_ids
         else:
             # embedding or reward model
             assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            return embeddings
+            ret = embeddings
+        return ret

     def process_batch_result(self, batch: ScheduleBatch, result):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
+            if batch.is_empty():
+                self.running_batch = None
         else:
             self.process_batch_result_prefill(batch, result)
@@ -728,7 +733,7 @@ class Scheduler:
             )
         else:
             # embedding or reward model
             assert batch.extend_num_tokens != 0
-            embeddings = result
+            embeddings = result.tolist()

         # Check finish conditions
         for i, req in enumerate(batch.reqs):
@@ -750,12 +755,6 @@ class Scheduler:
         self.handle_finished_requests(batch)

-        if not batch.is_empty():
-            if self.running_batch is None:
-                self.running_batch = batch
-            else:
-                self.running_batch.merge_batch(batch)
-
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
         if batch.sampling_info.penalizer_orchestrator:
@@ -951,7 +950,7 @@ class Scheduler:
         # Send to detokenizer
         if output_rids:
             if self.is_generation:
-                self.out_pyobjs.append(
+                self.send_to_detokenizer.send_pyobj(
                     BatchTokenIDOut(
                         output_rids,
                         output_vids,
@@ -965,7 +964,7 @@ class Scheduler:
                     )
                 )
             else:
                 # embedding or reward model
-                self.out_pyobjs.append(
+                self.send_to_detokenizer.send_pyobj(
                     BatchEmbeddingOut(
                         output_rids,
                         output_embeddings,

python/sglang/srt/managers/tp_worker.py

@@ -118,7 +118,7 @@ class TpModelWorker:
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
-        embeddings = logits_output.embeddings.tolist()
+        embeddings = logits_output.embeddings
         return embeddings

     def update_weights(self, recv_req: UpdateWeightReqInput):
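This mirrors the scheduler change above: forward_batch_embedding now returns the embeddings tensor as-is, and the .tolist() conversion happens later, in the scheduler's prefill result handling.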

python/sglang/srt/server_args.py

@@ -111,6 +111,7 @@ class ServerArgs:
     torchao_config: str = ""
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
+    num_continuous_decode_steps: int = 1

     def __post_init__(self):
         # Set missing default values
@@ -559,6 +560,14 @@ class ServerArgs:
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
             "This only affects Triton attention kernels.",
         )
+        parser.add_argument(
+            "--num-continuous-decode-steps",
+            type=int,
+            default=ServerArgs.num_continuous_decode_steps,
+            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
+            "This can potentially increase throughput but may also increase time-to-first-token latency. "
+            "The default value is 1, meaning only run one decoding step at a time.",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
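The new flag follows the existing ServerArgs pattern: the dataclass field supplies the default and an argparse option of the same name overrides it. A minimal, self-contained sketch of that pattern (illustration only; it does not import sglang, and ToyServerArgs is a placeholder for the relevant slice of ServerArgs):

    import argparse
    from dataclasses import dataclass

    # Stand-in for the relevant slice of ServerArgs.
    @dataclass
    class ToyServerArgs:
        num_continuous_decode_steps: int = 1

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-continuous-decode-steps",
        type=int,
        default=ToyServerArgs.num_continuous_decode_steps,
        help="Run multiple continuous decoding steps to reduce scheduling overhead.",
    )

    # e.g. a user passes `--num-continuous-decode-steps 4` on the command line.
    args = ToyServerArgs(**vars(parser.parse_args(["--num-continuous-decode-steps", "4"])))
    print(args.num_continuous_decode_steps)  # -> 4

In practice this means the server can be launched with, for example, --num-continuous-decode-steps 4 to trade a little time-to-first-token latency for higher decode throughput, as the help text explains.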