Commit 12cad0fe (unverified)
Simplify the interface of tp_worker (#1718)
Authored Oct 19, 2024 by Lianmin Zheng; committed by GitHub on Oct 19, 2024
Parent: b6cd9036

2 changed files, 41 additions and 29 deletions (+41, -29):
  python/sglang/srt/managers/scheduler.py  (+25, -28)
  python/sglang/srt/managers/tp_worker.py  (+16, -1)
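In summary, this commit tightens the boundary between the scheduler and the tensor-parallel worker: Scheduler (and run_scheduler_process) now threads a dp_rank argument through to TpModelWorker, and instead of reaching into tp_worker.model_runner the scheduler reads the device from get_token_and_memory_info() and uses three new accessors, get_tp_cpu_group(), get_pad_input_ids_func(), and get_memory_pool(). The overlap-schedule wiring also moves up into Scheduler.__init__, and the self.cache_finished_req alias is dropped in favor of calling self.tree_cache.cache_finished_req(req) directly.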
python/sglang/srt/managers/scheduler.py

@@ -91,6 +91,7 @@ class Scheduler:
         port_args: PortArgs,
         gpu_id: int,
         tp_rank: int,
+        dp_rank: Optional[int],
     ):
         # Parse args
         self.server_args = server_args
@@ -144,13 +145,24 @@ class Scheduler:
         # Launch a tensor parallel worker
         self.tp_worker = TpModelWorker(
-            server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
+            server_args=server_args,
+            dp_rank=dp_rank,
             nccl_port=port_args.nccl_port,
         )
-        self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
-        self.device = self.tp_worker.device
+
+        # Init states for overlap schedule
+        if self.server_args.enable_overlap_schedule:
+            self.forward_batch_generation = (
+                self.tp_worker.forward_batch_generation_non_blocking
+            )
+            self.resolve_next_token_ids = (
+                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
+            )
+        else:
+            self.forward_batch_generation = self.tp_worker.forward_batch_generation
+            self.resolve_next_token_ids = lambda bid, x: x.tolist()

         # Get token and memory info from the model worker
         (
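The hunk above also moves the overlap-schedule setup into __init__ right after the worker is created: forward_batch_generation and resolve_next_token_ids are bound once to either the non-blocking or the blocking worker entry points, so the event loops can call them without branching. Below is a rough, self-contained sketch of that dispatch pattern; FakeWorker and its token values are stand-ins for illustration only, not the real TpModelWorker behavior.

# Sketch of binding the two callables once, as the new __init__ code does.
# FakeWorker only imitates the method names that appear in the diff.
class FakeWorker:
    def __init__(self):
        self._future_token_ids = {}  # batch id -> token ids resolved later

    def forward_batch_generation(self, batch):
        # Blocking path: the forward pass is done, return token ids directly.
        return [101, 102, 103]

    def forward_batch_generation_non_blocking(self, batch):
        # Overlap path: hand back a batch id now, resolve token ids later.
        bid = id(batch)
        self._future_token_ids[bid] = [101, 102, 103]
        return bid

    def resolve_future_token_ids(self, bid):
        return self._future_token_ids.pop(bid)

def bind_generation_fns(worker, enable_overlap_schedule):
    if enable_overlap_schedule:
        forward_batch_generation = worker.forward_batch_generation_non_blocking
        resolve_next_token_ids = lambda bid, x: worker.resolve_future_token_ids(bid)
    else:
        forward_batch_generation = worker.forward_batch_generation
        resolve_next_token_ids = lambda bid, x: list(x)  # x.tolist() in the diff
    return forward_batch_generation, resolve_next_token_ids

# Usage: the caller sees the same two-function interface in both modes.
fwd, resolve = bind_generation_fns(FakeWorker(), enable_overlap_schedule=True)
batch = object()
handle = fwd(batch)
print(resolve(handle, None))  # [101, 102, 103]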
@@ -159,11 +171,11 @@ class Scheduler:
             self.max_running_requests,
             self.max_req_input_len,
             self.random_seed,
+            self.device,
         ) = self.tp_worker.get_token_and_memory_info()
+        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
+        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
         set_random_seed(self.random_seed)
-        self.pad_input_ids_func = getattr(
-            self.tp_worker.model_runner.model, "pad_input_ids", None
-        )
@@ -173,9 +185,8 @@ class Scheduler:
             f"context_len={self.model_config.context_len}"
         )

-        # Init cache
-        self.req_to_token_pool = self.tp_worker.model_runner.req_to_token_pool
-        self.token_to_kv_pool = self.tp_worker.model_runner.token_to_kv_pool
+        # Init memory pool and cache
+        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()
         if (
             server_args.chunked_prefill_size is not None
@@ -253,20 +264,6 @@ class Scheduler:
             with_stack=True,
         )

-        # Init states for overlap schedule
-        if self.server_args.enable_overlap_schedule:
-            self.forward_batch_generation = (
-                self.tp_worker.forward_batch_generation_non_blocking
-            )
-            self.resolve_next_token_ids = (
-                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
-            )
-            self.cache_finished_req = self.tree_cache.cache_finished_req
-        else:
-            self.forward_batch_generation = self.tp_worker.forward_batch_generation
-            self.resolve_next_token_ids = lambda bid, x: x.tolist()
-            self.cache_finished_req = self.tree_cache.cache_finished_req
-
     @torch.inference_mode()
     def event_loop_normal(self):
         self.last_batch = None
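Since the overlap-schedule block now lives in __init__ (first hunk of this file) and no longer defines the self.cache_finished_req alias, the remaining hunks in this file simply switch its call sites to self.tree_cache.cache_finished_req(req).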
@@ -779,7 +776,7 @@ class Scheduler:
                 req.check_finished()

                 if req.finished():
-                    self.cache_finished_req(req)
+                    self.tree_cache.cache_finished_req(req)
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     self.tree_cache.cache_unfinished_req(req)
@@ -808,7 +805,7 @@ class Scheduler:
             req.check_finished()

             if req.finished():
-                self.cache_finished_req(req)
+                self.tree_cache.cache_finished_req(req)
             else:
                 self.tree_cache.cache_unfinished_req(req)
@@ -845,7 +842,7 @@ class Scheduler:
             )

             if req.finished():
-                self.cache_finished_req(req)
+                self.tree_cache.cache_finished_req(req)

             if req.return_logprob:
                 req.output_token_logprobs.append(
@@ -1069,7 +1066,7 @@ class Scheduler:
         for req in self.running_batch.reqs:
             if req.rid == recv_req.rid and not req.finished():
                 req.finished_reason = FINISH_ABORT()
-                self.cache_finished_req(req)
+                self.tree_cache.cache_finished_req(req)
                 break

     def update_weights(self, recv_req: UpdateWeightReqInput):
@@ -1112,7 +1109,7 @@ def run_scheduler_process(
     suppress_other_loggers()

     try:
-        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
+        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
         pipe_writer.send("ready")
         if server_args.enable_overlap_schedule:
             scheduler.event_loop_overlap()
python/sglang/srt/managers/tp_worker.py

@@ -20,6 +20,7 @@ import logging
 import threading
 import time
 from queue import Queue
+from typing import Optional

 import torch
@@ -40,9 +41,10 @@ class TpModelWorker:
     def __init__(
         self,
-        server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
+        server_args: ServerArgs,
+        dp_rank: Optional[int],
         nccl_port: int,
     ):
         # Parse args
@@ -116,6 +118,19 @@ class TpModelWorker:
             self.max_running_requests,
             self.max_req_input_len,
             self.random_seed,
+            self.device,
         )
+
+    def get_pad_input_ids_func(self):
+        return getattr(self.model_runner.model, "pad_input_ids", None)
+
+    def get_tp_cpu_group(self):
+        return self.model_runner.tp_group.cpu_group
+
+    def get_memory_pool(self):
+        return (
+            self.model_runner.req_to_token_pool,
+            self.model_runner.token_to_kv_pool,
+        )

     def init_overlap_status(self):
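The three getters above are what the scheduler hunks consume in place of direct model_runner attribute access. Here is a minimal, self-contained sketch of that pattern; the stub classes stand in for sglang's ModelRunner and TpModelWorker, and only the method names and attribute paths follow the diff.

# Illustrative sketch only: stubs mirror the accessor pattern from this commit,
# not the real sglang classes.
class StubModelRunner:
    def __init__(self):
        self.req_to_token_pool = "req_to_token_pool"  # placeholder pools
        self.token_to_kv_pool = "token_to_kv_pool"
        self.tp_group = type("TpGroup", (), {"cpu_group": None})()
        self.model = object()  # a model without a pad_input_ids hook

class StubTpModelWorker:
    def __init__(self):
        self.model_runner = StubModelRunner()

    # Accessors added by this commit (bodies copied from the diff).
    def get_pad_input_ids_func(self):
        return getattr(self.model_runner.model, "pad_input_ids", None)

    def get_tp_cpu_group(self):
        return self.model_runner.tp_group.cpu_group

    def get_memory_pool(self):
        return (
            self.model_runner.req_to_token_pool,
            self.model_runner.token_to_kv_pool,
        )

# Scheduler-side consumption after the change: no model_runner access.
worker = StubTpModelWorker()
req_to_token_pool, token_to_kv_pool = worker.get_memory_pool()
tp_cpu_group = worker.get_tp_cpu_group()
pad_input_ids_func = worker.get_pad_input_ids_func()  # None for this stub

Keeping model_runner behind these accessors means the scheduler depends only on TpModelWorker's public surface, which appears to be what the commit title means by simplifying the interface.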