Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
62757db6
"vscode:/vscode.git/clone" did not exist on "ba6fd6eb30de97370f06f5804d9cc0e10b5718b5"
Unverified
Commit
62757db6
authored
Aug 09, 2024
by
Liangsheng Yin
Committed by
GitHub
Aug 09, 2024
Browse files
Reduce the overhead when cache is disabled (#1010)
parent
73fa2d49
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
35 additions
and
43 deletions
+35
-43
python/sglang/srt/managers/policy_scheduler.py
python/sglang/srt/managers/policy_scheduler.py
+20
-25
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+5
-0
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+4
-18
python/sglang/srt/mem_cache/radix_cache.py
python/sglang/srt/mem_cache/radix_cache.py
+6
-0
No files found.
python/sglang/srt/managers/policy_scheduler.py
View file @
62757db6
...
...
@@ -18,44 +18,40 @@ limitations under the License.
import
random
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
typing
import
List
from
sglang.srt.managers.schedule_batch
import
Req
,
ScheduleBatch
class
PolicyScheduler
:
def
__init__
(
self
,
policy
,
max_running_seqs
,
max_prefill_num_tokens
,
max_total_num_tokens
,
tree_cache
,
):
if
tree_cache
.
disable
and
policy
==
"lpm"
:
# LMP is meaningless when the tree cache is disabled.
def
__init__
(
self
,
policy
,
tree_cache
):
if
tree_cache
.
disable
and
policy
in
[
"lpm"
,
"dfs-weight"
]:
# LPM and DFS-weight are meaningless when the tree cache is disabled.
policy
=
"fcfs"
self
.
policy
=
policy
self
.
max_running_seqs
=
max_running_seqs
self
.
max_prefill_num_tokens
=
max_prefill_num_tokens
self
.
max_total_num_tokens
=
max_total_num_tokens
self
.
tree_cache
=
tree_cache
def
get_priority_queue
(
self
,
waiting_queue
):
def
calc_priority
(
self
,
waiting_queue
:
List
[
Req
]):
if
self
.
policy
in
[
"lpm"
,
"dfs-weight"
]:
# Compute matched prefix length
for
r
in
waiting_queue
:
# NOTE: the prefix_indices must always be aligned with last_node
r
.
prefix_indices
,
r
.
last_node
=
self
.
tree_cache
.
match_prefix
(
rid
=
r
.
rid
,
key
=
r
.
adjust_max_prefix_ids
()
)
if
self
.
policy
==
"lpm"
:
# longest prefix match
# Longest Prefix Match
waiting_queue
.
sort
(
key
=
lambda
x
:
-
len
(
x
.
prefix_indices
))
return
waiting_queue
elif
self
.
policy
==
"fcfs"
:
# first come first serve
return
waiting_queue
pass
elif
self
.
policy
==
"lof"
:
# longest output first
waiting_queue
.
sort
(
key
=
lambda
x
:
-
x
.
sampling_params
.
max_new_tokens
)
return
waiting_queue
elif
self
.
policy
==
"random"
:
random
.
shuffle
(
waiting_queue
)
return
waiting_queue
elif
self
.
policy
==
"dfs-weight"
:
last_node_to_reqs
=
defaultdict
(
list
)
for
req
in
waiting_queue
:
...
...
@@ -66,12 +62,13 @@ class PolicyScheduler:
node_to_weight
[
node
]
=
len
(
last_node_to_reqs
[
node
])
self
.
calc_weight
(
self
.
tree_cache
.
root_node
,
node_to_weight
)
q
=
[]
waiting_queue
.
clear
()
self
.
get_dfs_priority
(
self
.
tree_cache
.
root_node
,
node_to_weight
,
last_node_to_reqs
,
q
self
.
tree_cache
.
root_node
,
node_to_weight
,
last_node_to_reqs
,
waiting_queue
,
)
assert
len
(
q
)
==
len
(
waiting_queue
)
return
q
else
:
raise
ValueError
(
f
"Unknown schedule_policy:
{
self
.
policy
}
"
)
...
...
@@ -139,8 +136,6 @@ class PrefillAdder:
self
.
log_input_tokens
+=
extend_input_len
def
add_inflight_req
(
self
,
req
:
Req
):
req
.
input_ids
=
req
.
origin_input_ids
+
req
.
output_ids
req
.
extend_input_len
=
len
(
req
.
input_ids
)
-
len
(
req
.
prefix_indices
)
truncated
=
req
.
extend_input_len
>
self
.
rem_chunk_tokens
req
.
extend_input_len
=
min
(
req
.
extend_input_len
,
self
.
rem_chunk_tokens
)
req
.
input_ids
=
req
.
input_ids
[:
len
(
req
.
prefix_indices
)
+
req
.
extend_input_len
]
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
62757db6
...
...
@@ -164,7 +164,12 @@ class Req:
def
finished
(
self
)
->
bool
:
return
self
.
finished_reason
is
not
None
def
init_next_round_input
(
self
):
self
.
input_ids
=
self
.
origin_input_ids
+
self
.
output_ids
self
.
extend_input_len
=
len
(
self
.
input_ids
)
-
len
(
self
.
prefix_indices
)
def
adjust_max_prefix_ids
(
self
):
self
.
input_ids
=
self
.
origin_input_ids
+
self
.
output_ids
input_len
=
len
(
self
.
input_ids
)
max_prefix_len
=
input_len
...
...
python/sglang/srt/managers/tp_worker.py
View file @
62757db6
...
...
@@ -165,13 +165,7 @@ class ModelTpServer:
disable
=
server_args
.
disable_radix_cache
,
)
self
.
tree_cache_metrics
=
{
"total"
:
0
,
"hit"
:
0
}
self
.
scheduler
=
PolicyScheduler
(
self
.
schedule_policy
,
self
.
max_running_requests
,
self
.
max_prefill_tokens
,
self
.
max_total_num_tokens
,
self
.
tree_cache
,
)
self
.
scheduler
=
PolicyScheduler
(
self
.
schedule_policy
,
self
.
tree_cache
)
self
.
req_to_token_pool
=
self
.
model_runner
.
req_to_token_pool
self
.
token_to_kv_pool
=
self
.
model_runner
.
token_to_kv_pool
...
...
@@ -373,17 +367,8 @@ class ModelTpServer:
if
running_bs
>=
self
.
max_running_requests
:
return
None
# Compute matched prefix length
for
req
in
self
.
waiting_queue
:
req
.
input_ids
=
req
.
origin_input_ids
+
req
.
output_ids
# NOTE: the prefix_indices must always be aligned with last_node
req
.
prefix_indices
,
req
.
last_node
=
self
.
tree_cache
.
match_prefix
(
rid
=
req
.
rid
,
key
=
req
.
adjust_max_prefix_ids
()
)
req
.
extend_input_len
=
len
(
req
.
input_ids
)
-
len
(
req
.
prefix_indices
)
# Get priority queue
self
.
waiting_queue
=
self
.
scheduler
.
get_priority_queue
(
self
.
waiting_queue
)
self
.
scheduler
.
calc_priority
self
.
waiting_queue
)
adder
=
PrefillAdder
(
self
.
tree_cache
,
...
...
@@ -397,12 +382,13 @@ class ModelTpServer:
has_inflight
=
self
.
current_inflight_req
is
not
None
if
self
.
current_inflight_req
is
not
None
:
self
.
current_inflight_req
.
init_next_round_input
()
self
.
current_inflight_req
=
adder
.
add_inflight_req
(
self
.
current_inflight_req
)
for
req
in
self
.
waiting_queue
:
req
.
init_next_round_input
()
res
=
adder
.
add_one_req
(
req
)
if
(
not
res
...
...
python/sglang/srt/mem_cache/radix_cache.py
View file @
62757db6
...
...
@@ -169,6 +169,9 @@ class RadixCache(BasePrefixCache):
heapq
.
heappush
(
leaves
,
x
.
parent
)
def
inc_lock_ref
(
self
,
node
:
TreeNode
):
if
self
.
disable
:
return
0
delta
=
0
while
node
!=
self
.
root_node
:
if
node
.
lock_ref
==
0
:
...
...
@@ -179,6 +182,9 @@ class RadixCache(BasePrefixCache):
return
delta
def
dec_lock_ref
(
self
,
node
:
TreeNode
):
if
self
.
disable
:
return
0
delta
=
0
while
node
!=
self
.
root_node
:
if
node
.
lock_ref
==
1
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment