Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
62757db6
Unverified
Commit
62757db6
authored
Aug 09, 2024
by
Liangsheng Yin
Committed by
GitHub
Aug 09, 2024
Browse files
Reduce the overhead when cache is disabled (#1010)
parent
73fa2d49
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
35 additions
and
43 deletions
+35
-43
python/sglang/srt/managers/policy_scheduler.py
python/sglang/srt/managers/policy_scheduler.py
+20
-25
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+5
-0
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+4
-18
python/sglang/srt/mem_cache/radix_cache.py
python/sglang/srt/mem_cache/radix_cache.py
+6
-0
No files found.
python/sglang/srt/managers/policy_scheduler.py
View file @
62757db6
...
...
@@ -18,44 +18,40 @@ limitations under the License.
import
random
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
typing
import
List
from
sglang.srt.managers.schedule_batch
import
Req
,
ScheduleBatch
class
PolicyScheduler
:
def
__init__
(
self
,
policy
,
max_running_seqs
,
max_prefill_num_tokens
,
max_total_num_tokens
,
tree_cache
,
):
if
tree_cache
.
disable
and
policy
==
"lpm"
:
# LMP is meaningless when the tree cache is disabled.
def
__init__
(
self
,
policy
,
tree_cache
):
if
tree_cache
.
disable
and
policy
in
[
"lpm"
,
"dfs-weight"
]:
# LPM and DFS-weight is meaningless when the tree cache is disabled.
policy
=
"fcfs"
self
.
policy
=
policy
self
.
max_running_seqs
=
max_running_seqs
self
.
max_prefill_num_tokens
=
max_prefill_num_tokens
self
.
max_total_num_tokens
=
max_total_num_tokens
self
.
tree_cache
=
tree_cache
def
get_priority_queue
(
self
,
waiting_queue
):
def
calc_priority
(
self
,
waiting_queue
:
List
[
Req
]):
if
self
.
policy
in
[
"lpm"
,
"dfs-weight"
]:
# Compute matched prefix length
for
r
in
waiting_queue
:
# NOTE: the prefix_indices must always be aligned with last_node
r
.
prefix_indices
,
r
.
last_node
=
self
.
tree_cache
.
match_prefix
(
rid
=
r
.
rid
,
key
=
r
.
adjust_max_prefix_ids
()
)
if
self
.
policy
==
"lpm"
:
#
l
ongest
p
refix
m
atch
#
L
ongest
P
refix
M
atch
waiting_queue
.
sort
(
key
=
lambda
x
:
-
len
(
x
.
prefix_indices
))
return
waiting_queue
elif
self
.
policy
==
"fcfs"
:
# first come first serve
return
waiting_queue
pass
elif
self
.
policy
==
"lof"
:
# longest output first
waiting_queue
.
sort
(
key
=
lambda
x
:
-
x
.
sampling_params
.
max_new_tokens
)
return
waiting_queue
elif
self
.
policy
==
"random"
:
random
.
shuffle
(
waiting_queue
)
return
waiting_queue
elif
self
.
policy
==
"dfs-weight"
:
last_node_to_reqs
=
defaultdict
(
list
)
for
req
in
waiting_queue
:
...
...
@@ -66,12 +62,13 @@ class PolicyScheduler:
node_to_weight
[
node
]
=
len
(
last_node_to_reqs
[
node
])
self
.
calc_weight
(
self
.
tree_cache
.
root_node
,
node_to_weight
)
q
=
[]
waiting_queue
.
clear
()
self
.
get_dfs_priority
(
self
.
tree_cache
.
root_node
,
node_to_weight
,
last_node_to_reqs
,
q
self
.
tree_cache
.
root_node
,
node_to_weight
,
last_node_to_reqs
,
waiting_queue
,
)
assert
len
(
q
)
==
len
(
waiting_queue
)
return
q
else
:
raise
ValueError
(
f
"Unknown schedule_policy:
{
self
.
policy
}
"
)
...
...
@@ -139,8 +136,6 @@ class PrefillAdder:
self
.
log_input_tokens
+=
extend_input_len
def
add_inflight_req
(
self
,
req
:
Req
):
req
.
input_ids
=
req
.
origin_input_ids
+
req
.
output_ids
req
.
extend_input_len
=
len
(
req
.
input_ids
)
-
len
(
req
.
prefix_indices
)
truncated
=
req
.
extend_input_len
>
self
.
rem_chunk_tokens
req
.
extend_input_len
=
min
(
req
.
extend_input_len
,
self
.
rem_chunk_tokens
)
req
.
input_ids
=
req
.
input_ids
[:
len
(
req
.
prefix_indices
)
+
req
.
extend_input_len
]
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
62757db6
...
...
@@ -164,7 +164,12 @@ class Req:
def
finished
(
self
)
->
bool
:
return
self
.
finished_reason
is
not
None
def
init_next_round_input
(
self
):
self
.
input_ids
=
self
.
origin_input_ids
+
self
.
output_ids
self
.
extend_input_len
=
len
(
self
.
input_ids
)
-
len
(
self
.
prefix_indices
)
def
adjust_max_prefix_ids
(
self
):
self
.
input_ids
=
self
.
origin_input_ids
+
self
.
output_ids
input_len
=
len
(
self
.
input_ids
)
max_prefix_len
=
input_len
...
...
python/sglang/srt/managers/tp_worker.py
View file @
62757db6
...
...
@@ -165,13 +165,7 @@ class ModelTpServer:
disable
=
server_args
.
disable_radix_cache
,
)
self
.
tree_cache_metrics
=
{
"total"
:
0
,
"hit"
:
0
}
self
.
scheduler
=
PolicyScheduler
(
self
.
schedule_policy
,
self
.
max_running_requests
,
self
.
max_prefill_tokens
,
self
.
max_total_num_tokens
,
self
.
tree_cache
,
)
self
.
scheduler
=
PolicyScheduler
(
self
.
schedule_policy
,
self
.
tree_cache
)
self
.
req_to_token_pool
=
self
.
model_runner
.
req_to_token_pool
self
.
token_to_kv_pool
=
self
.
model_runner
.
token_to_kv_pool
...
...
@@ -373,17 +367,8 @@ class ModelTpServer:
if
running_bs
>=
self
.
max_running_requests
:
return
None
# Compute matched prefix length
for
req
in
self
.
waiting_queue
:
req
.
input_ids
=
req
.
origin_input_ids
+
req
.
output_ids
# NOTE: the prefix_indices must always be aligned with last_node
req
.
prefix_indices
,
req
.
last_node
=
self
.
tree_cache
.
match_prefix
(
rid
=
req
.
rid
,
key
=
req
.
adjust_max_prefix_ids
()
)
req
.
extend_input_len
=
len
(
req
.
input_ids
)
-
len
(
req
.
prefix_indices
)
# Get priority queue
self
.
waiting_queue
=
self
.
scheduler
.
get
_priority
_queue
(
self
.
waiting_queue
)
self
.
scheduler
.
calc
_priority
(
self
.
waiting_queue
)
adder
=
PrefillAdder
(
self
.
tree_cache
,
...
...
@@ -397,12 +382,13 @@ class ModelTpServer:
has_inflight
=
self
.
current_inflight_req
is
not
None
if
self
.
current_inflight_req
is
not
None
:
self
.
current_inflight_req
.
init_next_round_input
()
self
.
current_inflight_req
=
adder
.
add_inflight_req
(
self
.
current_inflight_req
)
for
req
in
self
.
waiting_queue
:
req
.
init_next_round_input
()
res
=
adder
.
add_one_req
(
req
)
if
(
not
res
...
...
python/sglang/srt/mem_cache/radix_cache.py
View file @
62757db6
...
...
@@ -169,6 +169,9 @@ class RadixCache(BasePrefixCache):
heapq
.
heappush
(
leaves
,
x
.
parent
)
def
inc_lock_ref
(
self
,
node
:
TreeNode
):
if
self
.
disable
:
return
0
delta
=
0
while
node
!=
self
.
root_node
:
if
node
.
lock_ref
==
0
:
...
...
@@ -179,6 +182,9 @@ class RadixCache(BasePrefixCache):
return
delta
def
dec_lock_ref
(
self
,
node
:
TreeNode
):
if
self
.
disable
:
return
0
delta
=
0
while
node
!=
self
.
root_node
:
if
node
.
lock_ref
==
1
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment