Commit ebbc42d9
Authored Oct 07, 2024 by Lianmin Zheng

Optimize broadcast & Reorg code (#1598)

Parent: 3ff64113
Showing 3 changed files with 55 additions and 47 deletions (+55, -47):

- python/sglang/srt/layers/activation.py (+1, -1)
- python/sglang/srt/managers/scheduler.py (+46, -45)
- python/sglang/srt/utils.py (+8, -1)
python/sglang/srt/layers/activation.py

@@ -148,6 +148,6 @@ def get_act_fn(
     if not is_flashinfer_available():
         logger.info(
-            "FlashInfer is not available on Non-NV GPUs. Fallback to other kernel libraries."
+            "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries."
         )
         from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
python/sglang/srt/managers/scheduler.py

@@ -234,14 +234,9 @@ class Scheduler:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)

-            # Run one step
             self.run_step()

-            # Send results
-            if self.tp_rank == 0:
-                for obj in self.out_pyobjs:
-                    self.send_to_detokenizer.send_pyobj(obj)
-                self.out_pyobjs = []
+            self.send_results()

     def recv_requests(self):
         if self.tp_rank == 0:
@@ -256,7 +251,8 @@ class Scheduler:
         else:
             recv_reqs = None

-        recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
+        if self.tp_size != 1:
+            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
         return recv_reqs

     def process_input_requests(self, recv_reqs: List):
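The new `tp_size != 1` guard skips the collective entirely when the scheduler runs on a single tensor-parallel rank: with no peer to synchronize with, the pickle round-trip and `dist.broadcast` calls are pure overhead. Below is a minimal sketch of the pattern (an assumed simplification, not sglang's exact `broadcast_pyobj`; the helper name and signature are illustrative):

    import pickle

    import numpy as np
    import torch
    import torch.distributed as dist


    def broadcast_pyobj_sketch(data, rank: int, dist_group, world_size: int):
        """Broadcast an arbitrary picklable object from rank 0 to all ranks."""
        if world_size == 1:
            # Fast path mirroring the commit's guard: nothing to synchronize.
            return data

        if rank == 0:
            serialized = pickle.dumps(data)
            tensor_size = torch.tensor([len(serialized)], dtype=torch.long)
            # Wrap the raw bytes without building a Python list of ints first
            # (the optimization applied in python/sglang/srt/utils.py below).
            tensor_data = torch.ByteTensor(np.frombuffer(serialized, dtype=np.uint8))
            dist.broadcast(tensor_size, src=0, group=dist_group)
            dist.broadcast(tensor_data, src=0, group=dist_group)
            return data
        else:
            tensor_size = torch.tensor([0], dtype=torch.long)
            dist.broadcast(tensor_size, src=0, group=dist_group)
            tensor_data = torch.empty(int(tensor_size.item()), dtype=torch.uint8)
            dist.broadcast(tensor_data, src=0, group=dist_group)
            return pickle.loads(tensor_data.numpy().tobytes())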
@@ -366,43 +362,11 @@ class Scheduler:
             self.waiting_queue.append(req)

-    def run_step(self):
-        new_batch = self.get_new_batch_prefill()
-        if new_batch is not None:
-            # Run a new prefill batch
-            result = self.run_batch(new_batch)
-            self.process_batch_result(new_batch, result)
-
-            if not new_batch.is_empty():
-                if self.running_batch is None:
-                    self.running_batch = new_batch
-                else:
-                    self.running_batch.merge_batch(new_batch)
-        else:
-            # Run a decode batch
-            if self.running_batch is not None:
-                # Run a few decode batches continuously for reducing overhead
-                for _ in range(global_config.num_continue_decode_steps):
-                    batch = self.get_new_batch_decode()
-
-                    if batch:
-                        result = self.run_batch(batch)
-                        self.process_batch_result(batch, result)
-
-                    # Print stats
-                    if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
-                        self.print_decode_stats()
-
-                    if self.running_batch.is_empty():
-                        self.running_batch = None
-                        break
-
-                    if self.out_pyobjs and self.running_batch.has_stream:
-                        break
-            else:
-                self.check_memory()
-                self.new_token_ratio = global_config.init_new_token_ratio
-
+    def send_results(self):
+        if self.tp_rank == 0:
+            for obj in self.out_pyobjs:
+                self.send_to_detokenizer.send_pyobj(obj)
+            self.out_pyobjs = []

     def print_decode_stats(self):
         num_used = self.max_total_num_tokens - (
@@ -441,6 +405,31 @@ class Scheduler:
             )
             exit(1) if crash_on_warning else None

+    def run_step(self):
+        new_batch = self.get_new_batch_prefill()
+        if new_batch is not None:
+            # Run a new prefill batch
+            result = self.run_batch(new_batch)
+            self.process_batch_result(new_batch, result)
+        else:
+            if self.running_batch is not None:
+                # Run a few decode batches continuously for reducing overhead
+                for _ in range(global_config.num_continue_decode_steps):
+                    batch = self.get_new_batch_decode()
+
+                    if batch:
+                        result = self.run_batch(batch)
+                        self.process_batch_result(batch, result)
+
+                    if self.running_batch is None:
+                        break
+
+                    if self.out_pyobjs and self.running_batch.has_stream:
+                        break
+            else:
+                self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio
+
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Handle the cases where prefill is not allowed
         if (
@@ -612,7 +601,6 @@ class Scheduler:
             return None

         # Update batch tensors
-        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         batch.prepare_for_decode()
         return batch
@@ -723,6 +711,12 @@ class Scheduler:
         self.handle_finished_requests(batch)

+        if not batch.is_empty():
+            if self.running_batch is None:
+                self.running_batch = batch
+            else:
+                self.running_batch.merge_batch(batch)
+
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
@@ -762,6 +756,13 @@ class Scheduler:
         self.handle_finished_requests(batch)

+        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
+        if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
+            self.print_decode_stats()
+
+        if self.running_batch.is_empty():
+            self.running_batch = None
+
     def add_logprob_return_values(
         self,
         i: int,
python/sglang/srt/utils.py

@@ -24,6 +24,7 @@ import random
 import resource
 import socket
 import time
+import warnings
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import Any, Dict, List, Optional, Union
@@ -333,6 +334,10 @@ def suppress_other_loggers():
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
     logging.getLogger("vllm.utils").setLevel(logging.ERROR)

+    warnings.filterwarnings(
+        "ignore", category=UserWarning, message="The given NumPy array is not writable"
+    )
+

 def assert_pkg_version(pkg: str, min_version: str, message: str):
     try:
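This filter is tied to the `broadcast_pyobj` change below: `np.frombuffer` returns a read-only view over the serialized bytes, and handing a non-writable NumPy array to a torch tensor constructor is expected to emit exactly this UserWarning even though the data is only ever copied, never mutated. A small repro sketch (assumed behavior; worth verifying on your torch version):

    import warnings

    import numpy as np
    import torch

    buf = b"some pickled payload"
    arr = np.frombuffer(buf, dtype=np.uint8)  # zero-copy, read-only view
    print(arr.flags.writeable)  # False

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Wrapping the read-only array should trigger the UserWarning
        # ("The given NumPy array is not writable ...") that the commit silences.
        tensor = torch.ByteTensor(arr)

    print([str(w.message) for w in caught])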
@@ -615,7 +620,9 @@ def broadcast_pyobj(
     else:
         serialized_data = pickle.dumps(data)
         size = len(serialized_data)
-        tensor_data = torch.ByteTensor(list(serialized_data))
+        tensor_data = torch.ByteTensor(
+            np.frombuffer(serialized_data, dtype=np.uint8)
+        )
         tensor_size = torch.tensor([size], dtype=torch.long)

         dist.broadcast(tensor_size, src=0, group=dist_group)
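This hunk is the core of the "optimize broadcast" part of the commit: `list(serialized_data)` expands every byte into a separate Python int object that torch then copies element by element, while `np.frombuffer` hands torch one contiguous uint8 buffer it can copy in bulk. A rough micro-benchmark sketch (timings are illustrative and machine-dependent):

    import pickle
    import time

    import numpy as np
    import torch

    payload = pickle.dumps(list(range(200_000)))  # roughly 1 MB of serialized data

    start = time.perf_counter()
    slow = torch.ByteTensor(list(payload))  # per-byte Python ints, element-wise copy
    t_slow = time.perf_counter() - start

    start = time.perf_counter()
    fast = torch.ByteTensor(np.frombuffer(payload, dtype=np.uint8))  # bulk copy
    t_fast = time.perf_counter() - start

    assert torch.equal(slow, fast)
    print(f"list(): {t_slow * 1e3:.1f} ms, np.frombuffer(): {t_fast * 1e3:.1f} ms")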