change / sglang / Commits / dbec2f18

Unverified commit dbec2f18, authored Oct 16, 2024 by Lianmin Zheng, committed via GitHub on Oct 16, 2024

Launch a thread to overlap CPU and GPU (#1687)

Parent: e4b367ba
Changes: 3 changed files with 142 additions and 20 deletions (+142 −20)
python/sglang/srt/managers/scheduler.py   +29  −18
python/sglang/srt/managers/tp_worker.py   +111 −0
python/sglang/srt/server.py               +2   −2
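The core idea of this commit is to run the GPU forward pass on a dedicated worker thread so that CPU-side work in the scheduler (building the next batch, detokenization, cache bookkeeping) overlaps with GPU execution. Before the diffs, here is a minimal, self-contained Python sketch of the same producer/consumer pattern, using a queue, a background thread, and per-batch events; all names and the fake forward step are hypothetical illustrations, not sglang's API.

import threading
from queue import Queue


class ToyOverlapWorker:
    """Toy version of the overlap pattern: submit() returns immediately so the
    scheduler thread can keep working; a background thread runs the (slow)
    forward pass and publishes results; resolve() blocks only when needed."""

    def __init__(self):
        self.queue = Queue()
        self.results = {}   # bid -> token ids
        self.events = {}    # bid -> threading.Event
        threading.Thread(target=self._loop, daemon=True).start()

    def _loop(self):
        while True:
            bid, input_ids = self.queue.get()
            token_ids = [x + 1 for x in input_ids]  # stand-in for the GPU forward
            self.results[bid] = token_ids
            self.events[bid].set()                  # signal "bid is ready"

    def submit(self, bid, input_ids):
        # Non-blocking: hand the batch to the worker thread and return.
        self.events[bid] = threading.Event()
        self.queue.put((bid, input_ids))

    def resolve(self, bid):
        # Blocking: wait for the worker thread, then hand back the real tokens.
        self.events[bid].wait()
        del self.events[bid]
        return self.results.pop(bid)


if __name__ == "__main__":
    worker = ToyOverlapWorker()
    worker.submit(bid=0, input_ids=[10, 20, 30])
    # ... CPU-side scheduling for the next batch would run here ...
    print(worker.resolve(bid=0))  # [11, 21, 31]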
python/sglang/srt/managers/scheduler.py (view file @ dbec2f18)

@@ -193,16 +193,6 @@ class Scheduler:
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

-        if self.server_args.enable_overlap_schedule:
-
-            def cache_finished_req(req):
-                free_delta = int(self.running_batch and req in self.cur_batch.reqs)
-                self.tree_cache.cache_finished_req(req, free_delta=free_delta)
-
-        else:
-            cache_finished_req = self.tree_cache.cache_finished_req
-
-        self.cache_finished_req = cache_finished_req
         # Init running status
         self.waiting_queue: List[Req] = []
         self.running_batch: Optional[ScheduleBatch] = None

@@ -245,6 +235,7 @@ class Scheduler:
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
         self.batch_is_full = False

+        # Init profiler
         if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
             self.profiler = None
         else:

@@ -261,6 +252,25 @@ class Scheduler:
                 with_stack=True,
             )

+        # Init states for overlap schedule
+        if self.server_args.enable_overlap_schedule:
+            self.forward_batch_generation = (
+                self.tp_worker.forward_batch_generation_non_blocking
+            )
+            self.resolve_next_token_ids = (
+                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
+            )
+
+            def cache_finished_req(req):
+                free_delta = int(self.running_batch and req in self.cur_batch.reqs)
+                self.tree_cache.cache_finished_req(req, free_delta=free_delta)
+
+            self.cache_finished_req = cache_finished_req
+        else:
+            self.forward_batch_generation = self.tp_worker.forward_batch_generation
+            self.resolve_next_token_ids = lambda bid, x: x.tolist()
+            self.cache_finished_req = self.tree_cache.cache_finished_req
+
     @torch.inference_mode()
     def event_loop_normal(self):
         self.last_batch = None

@@ -712,7 +722,7 @@ class Scheduler:
         if self.is_generation:
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()
-                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+                logits_output, next_token_ids = self.forward_batch_generation(
                     model_worker_batch
                 )
             else:

@@ -724,12 +734,12 @@ class Scheduler:
             else:
                 next_token_ids = torch.full((batch.batch_size(),), 0)
             batch.output_ids = next_token_ids
-            ret = logits_output, next_token_ids
+            ret = logits_output, next_token_ids, model_worker_batch.bid
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            ret = embeddings
+            ret = embeddings, model_worker_batch.bid
         return ret

     def process_batch_result(self, batch: ScheduleBatch, result):

@@ -742,7 +752,7 @@ class Scheduler:
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
-            logits_output, next_token_ids = result
+            logits_output, next_token_ids, bid = result

             if batch.return_logprob:
                 # Move logprobs to cpu
                 if logits_output.next_token_logprobs is not None:

@@ -761,7 +771,7 @@ class Scheduler:
                     logits_output.normalized_prompt_logprobs.tolist()
                 )

-            next_token_ids = next_token_ids.tolist()
+            next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)

             # Check finish conditions
             logprob_pt = 0

@@ -790,7 +800,8 @@ class Scheduler:
                     )
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
-            embeddings = result.tolist()
+            embeddings, bid = result
+            embeddings = embeddings.tolist()

             # Check finish conditions
             for i, req in enumerate(batch.reqs):

@@ -811,7 +822,7 @@ class Scheduler:
                 self.stream_output(batch.reqs)

     def process_batch_result_decode(self, batch: ScheduleBatch, result):
-        logits_output, next_token_ids = result
+        logits_output, next_token_ids, bid = result
         self.num_generated_tokens += len(batch.reqs)

         # Move logprobs to cpu

@@ -821,7 +832,7 @@ class Scheduler:
                 next_token_ids,
             ].tolist()

-        next_token_ids = next_token_ids.tolist()
+        next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)

         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
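A note on the scheduler changes above: the choice between the blocking and the overlapped path is made once in __init__ by binding self.forward_batch_generation and self.resolve_next_token_ids, and run_batch now returns the batch id (bid) alongside the result so the token ids can be resolved later. The runnable toy below shows only that dispatch; FakeWorker and its return values are hypothetical stand-ins, not the real TpModelWorker.

class FakeWorker:
    def __init__(self):
        self.pending = {}  # bid -> real token ids

    def forward_batch_generation(self, batch):
        # Blocking path: real token ids are available right away.
        return "logits", [tok + 1 for tok in batch["input_ids"]]

    def forward_batch_generation_non_blocking(self, batch):
        # Overlap path: only placeholders come back; real ids arrive later.
        self.pending[batch["bid"]] = [tok + 1 for tok in batch["input_ids"]]
        return "logits", [-1] * len(batch["input_ids"])

    def resolve_future_token_ids(self, bid):
        return self.pending.pop(bid)


def bind_paths(worker, enable_overlap_schedule):
    # Mirrors the if/else added to Scheduler.__init__ in the diff above.
    if enable_overlap_schedule:
        forward = worker.forward_batch_generation_non_blocking
        resolve = lambda bid, x: worker.resolve_future_token_ids(bid)
    else:
        forward = worker.forward_batch_generation
        resolve = lambda bid, x: list(x)
    return forward, resolve


for overlap in (False, True):
    worker = FakeWorker()
    forward, resolve = bind_paths(worker, overlap)
    batch = {"bid": 7, "input_ids": [10, 20]}
    _, next_token_ids = forward(batch)
    print(overlap, resolve(batch["bid"], next_token_ids))  # both: [11, 21]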
python/sglang/srt/managers/tp_worker.py (view file @ dbec2f18)

@@ -17,6 +17,11 @@ limitations under the License.

 import json
 import logging
+import threading
+import time
+from queue import Queue
+
+import torch

 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer

@@ -75,6 +80,7 @@ class TpModelWorker:
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
             )
+        self.device = self.model_runner.device

         # Profile number of tokens
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens

@@ -100,6 +106,9 @@ class TpModelWorker:
             )[0]
         set_random_seed(self.random_seed)

+        if server_args.enable_overlap_schedule:
+            self.init_overlap_status()
+
     def get_token_and_memory_info(self):
         return (
             self.max_total_num_tokens,

@@ -109,6 +118,83 @@ class TpModelWorker:
             self.random_seed,
         )

+    def init_overlap_status(self):
+        self.future_logits_output_dict = dict()
+        self.future_logits_output_ct = 0
+
+        self.future_token_ids_ct = 0
+        self.future_token_ids_map = torch.empty(
+            (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
+        )
+        self.future_token_ids_limit = self.max_running_requests * 3
+        self.future_token_ids_output = dict()
+
+        self.future_event_map = dict()
+        self.forward_queue = Queue()
+        self.forward_stream = torch.cuda.Stream()
+        self.forward_thread = threading.Thread(
+            target=self.forward_thread_func,
+        )
+        self.forward_thread.start()
+
+    def forward_thread_func(self):
+        with torch.cuda.stream(self.forward_stream):
+            self.forward_thread_func_()
+
+    @torch.inference_mode()
+    def forward_thread_func_(self):
+        while True:
+            tic1 = time.time()
+            model_worker_batch, future_logits_output, future_next_token_ids = (
+                self.forward_queue.get()
+            )
+
+            # Resolve future tokens in the input
+            # logger.info(f"raw input {model_worker_batch.input_ids=}")
+            tic2 = time.time()
+            resolved_input_ids = model_worker_batch.input_ids
+            future_mask = resolved_input_ids < 0
+            resolved_input_ids[future_mask] = self.future_token_ids_map[
+                -resolved_input_ids[future_mask]
+            ]
+            # logger.info(f"resolved input {model_worker_batch.input_ids=}")

+            # Run forward
+            logits_output, next_token_ids = self.forward_batch_generation(
+                model_worker_batch
+            )
+
+            # Set future values
+            if model_worker_batch.return_logprob:
+                self.future_logits_output_dict[future_logits_output] = logits_output
+
+            # logger.info(f"set output {future_next_token_ids=}, {next_token_ids=}")
+            self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
+                torch.int32
+            )
+            # logger.info("Set event")
+            self.future_token_ids_output[model_worker_batch.bid] = (
+                next_token_ids.tolist()
+            )
+            self.future_event_map[model_worker_batch.bid].set()
+
+            if False:
+                tic3 = time.time()
+                self.acc_time_with_waiting += tic3 - tic1
+                self.acc_time_without_waiting += tic3 - tic2
+                if self.forward_queue.qsize() == 0:
+                    logger.info(
+                        f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
+                    )
+
+    def resolve_future_token_ids(self, bid: int):
+        self.future_event_map[bid].wait()
+        ret = self.future_token_ids_output[bid]
+        del self.future_event_map[bid]
+        return ret
+
+    def resolve_future_logits_output(self, future_obj):
+        return self.future_logits_output_dict.pop(future_obj)
+
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)

@@ -121,6 +207,31 @@ class TpModelWorker:
         embeddings = logits_output.embeddings
         return embeddings

+    def forward_batch_generation_non_blocking(
+        self, model_worker_batch: ModelWorkerBatch
+    ):
+        # Allocate output future objects
+        future_logits_output = self.future_logits_output_ct
+        self.future_logits_output_ct += 1
+
+        bs = len(model_worker_batch.seq_lens)
+        future_next_token_ids = -torch.arange(
+            self.future_token_ids_ct + 1,
+            self.future_token_ids_ct + 1 + bs,
+            dtype=torch.int32,
+            device=self.device,
+        )
+        self.future_token_ids_ct = (
+            self.future_token_ids_ct + bs
+        ) % self.future_token_ids_limit
+        ret = future_logits_output, future_next_token_ids
+
+        self.future_event_map[model_worker_batch.bid] = threading.Event()
+        self.forward_queue.put(
+            (model_worker_batch.copy(), future_logits_output, future_next_token_ids)
+        )
+        return ret
+
     def update_weights(self, recv_req: UpdateWeightReqInput):
         success, message = self.model_runner.update_weights(
             recv_req.model_path, recv_req.load_format
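The key mechanism in TpModelWorker above is that placeholder token ids are negative indices into a preallocated future_token_ids_map: forward_batch_generation_non_blocking hands out -(ct+1) .. -(ct+bs), the forward thread later writes the real sampled ids into those slots, and any subsequent batch whose input_ids still contain negatives gets them patched in forward_thread_func_ before the model runs. Below is a small stand-alone demonstration of that resolution step, using CPU tensors and a tiny map size for illustration (the real code keeps an int32 map on the GPU).

import torch

# Preallocated map of future token ids; slot 0 is unused because ids start at -1.
future_token_ids_map = torch.zeros(16, dtype=torch.int64)

# Step 1: a non-blocking submit for a batch of 3 requests hands out the
# placeholders -1, -2, -3 (i.e. slots 1..3 of the map).
future_token_ids_ct = 0
bs = 3
future_next_token_ids = -torch.arange(
    future_token_ids_ct + 1, future_token_ids_ct + 1 + bs
)

# Step 2: the forward thread finishes and stores the real sampled ids.
real_next_token_ids = torch.tensor([101, 102, 103])
future_token_ids_map[-future_next_token_ids] = real_next_token_ids

# Step 3: a later batch that was built before the results were known still
# carries negative placeholders in its input_ids; resolve them in place.
input_ids = torch.tensor([5, -1, 9, -3])
future_mask = input_ids < 0
input_ids[future_mask] = future_token_ids_map[-input_ids[future_mask]]
print(input_ids.tolist())  # [5, 101, 9, 103]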
python/sglang/srt/server.py (view file @ dbec2f18)

@@ -447,7 +447,7 @@ def _set_envs_and_config(server_args: ServerArgs):
    os.environ["NCCL_CUMEM_ENABLE"] = "0"
    os.environ["NCCL_NVLS_ENABLE"] = "0"
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-   os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
+   os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"

    # Set ulimit
    set_ulimit()

@@ -528,7 +528,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
        kill_child_process(pid, including_parent=False)
        return

-   # print(f"{res.json()=}")
+   print(f"{res.json()=}")
    logger.info("The server is fired up and ready to roll!")
    if pipe_finish_writer is not None:
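About the server.py change: CUDA_DEVICE_MAX_CONNECTIONS controls how many hardware work queues the CUDA driver keeps per device. With the previous value of 1, kernels launched from the scheduler's stream and from the new forward thread's dedicated stream share a single queue and can serialize; raising it to 4 presumably gives the two streams room to actually overlap. The variable has to be set before the CUDA context is created, which is why _set_envs_and_config does it at process startup. A minimal, illustrative sketch (not the commit's code) of setting it and then submitting work on two streams:

import os

# Must be set before the first CUDA call in the process.
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"

import torch  # imported after setting the env var on purpose

if torch.cuda.is_available():
    s1, s2 = torch.cuda.Stream(), torch.cuda.Stream()
    a = torch.randn(4096, 4096, device="cuda")
    torch.cuda.current_stream().synchronize()  # make sure `a` is ready
    with torch.cuda.stream(s1):
        b = a @ a  # submitted on stream 1
    with torch.cuda.stream(s2):
        c = a + 1  # submitted on stream 2; free to overlap with the matmul
    torch.cuda.synchronize()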