Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
bdb3929d
Unverified
Commit
bdb3929d
authored
Jan 04, 2025
by
libra
Committed by
GitHub
Jan 04, 2025
Browse files
Refactor SchedulePolicy to improve code organization (#2571)
parent
f5d0865b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
211 additions
and
90 deletions
+211
-90
python/sglang/srt/managers/schedule_policy.py
python/sglang/srt/managers/schedule_policy.py
+159
-90
test/srt/test_schedule_policy.py
test/srt/test_schedule_policy.py
+52
-0
No files found.
python/sglang/srt/managers/schedule_policy.py
View file @
bdb3929d
...
...
@@ -18,7 +18,7 @@ import random
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
enum
import
Enum
,
auto
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Union
import
torch
...
...
@@ -50,13 +50,26 @@ IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
)
class CacheAwarePolicy(Enum):
    """Scheduling policies that are aware of the tree cache."""

    # Longest prefix match: prefer requests whose prompt shares the longest
    # prefix with entries already in the radix tree cache.
    LPM = "lpm"
    # Depth-first-search weighting over the radix tree: schedule requests
    # grouped by subtree so cached prefixes are reused together.
    DFS_WEIGHT = "dfs-weight"
class CacheAgnosticPolicy(Enum):
    """Scheduling policies that are not aware of the tree cache."""

    # First come, first served: keep the queue in arrival order.
    FCFS = "fcfs"
    # Longest output first: prefer requests with the largest max_new_tokens.
    LOF = "lof"
    # Uniformly random ordering.
    RANDOM = "random"
# Union of the two policy enums; the active policy is always one of these.
Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]


class SchedulePolicy:
    """Orders the waiting queue of requests according to a scheduling policy.

    The policy string is validated once at construction time and stored as an
    enum member (``CacheAwarePolicy`` or ``CacheAgnosticPolicy``).
    """

    def __init__(self, policy: str, tree_cache: BasePrefixCache):
        # Normalizes the string to an enum; falls back to FCFS when the tree
        # cache is disabled, because cache-aware policies (LPM / DFS_WEIGHT)
        # are meaningless without a prefix cache.
        self.policy = self._validate_and_adjust_policy(policy, tree_cache)
        self.tree_cache = tree_cache
# It is used to find the matching prefix for in-batch prefix caching.
...
...
@@ -64,110 +77,166 @@ class SchedulePolicy:
req_to_token_pool
=
None
,
token_to_kv_pool
=
None
,
disable
=
False
)
def calc_priority(self, waiting_queue: List[Req]) -> bool:
    """Reorder ``waiting_queue`` in place according to the active policy.

    Returns:
        True when the cache-aware path ran, i.e. ``r.prefix_indices`` and
        ``r.last_node`` have been (re)computed for every queued request;
        False otherwise.
    """
    # May downgrade LPM to FCFS for very large queues (see
    # _determine_active_policy).
    policy = self._determine_active_policy(waiting_queue)

    # Compute matched prefix length
    prefix_computed = False
    if isinstance(policy, CacheAwarePolicy):
        prefix_computed = True
        # Also collects the rids to deprioritize for in-batch prefix caching.
        temporary_deprioritized = self._compute_prefix_matches(
            waiting_queue, policy
        )
        if policy == CacheAwarePolicy.LPM:
            SchedulePolicy._sort_by_longest_prefix(
                waiting_queue, temporary_deprioritized
            )
        elif policy == CacheAwarePolicy.DFS_WEIGHT:
            SchedulePolicy._sort_by_dfs_weight(waiting_queue, self.tree_cache)
        else:
            raise ValueError(f"Unknown CacheAware Policy: {policy=}")
    else:
        if policy == CacheAgnosticPolicy.FCFS:
            # Arrival order is already the desired order.
            pass
        elif policy == CacheAgnosticPolicy.LOF:
            SchedulePolicy._sort_by_longest_output(waiting_queue)
        elif policy == CacheAgnosticPolicy.RANDOM:
            SchedulePolicy._sort_randomly(waiting_queue)
        else:
            raise ValueError(f"Unknown CacheAgnostic Policy: {policy=}")
    return prefix_computed
def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
    """Return the policy to apply for this scheduling round.

    LPM prefix matching and sorting are expensive, so for large queues
    (more than 128 requests) we temporarily fall back to FCFS.
    """
    lpm_is_too_expensive = (
        self.policy == CacheAwarePolicy.LPM and len(waiting_queue) > 128
    )
    if lpm_is_too_expensive:
        return CacheAgnosticPolicy.FCFS
    return self.policy
def _validate_and_adjust_policy(
    self, policy: str, tree_cache: BasePrefixCache
) -> Policy:
    """Validate the policy string and adjust it based on tree-cache settings.

    Raises:
        ValueError: if ``policy`` names neither a cache-aware nor a
            cache-agnostic policy.
    """
    try:
        cache_aware = CacheAwarePolicy(policy)
    except ValueError:
        # Not cache-aware; fall through to the cache-agnostic set below.
        pass
    else:
        if tree_cache.disable:
            # Cache-aware policies are meaningless without the tree cache,
            # so degrade to plain FCFS.
            return CacheAgnosticPolicy.FCFS
        return cache_aware

    try:
        return CacheAgnosticPolicy(policy)
    except ValueError:
        raise ValueError(f"Unknown schedule_policy: {policy=}")
def _compute_prefix_matches(
    self, waiting_queue: List[Req], policy: CacheAwarePolicy
) -> Set[int]:
    """
    Computes and caches the matching prefixes for requests in the waiting queue,
    and handles in-batch prefix caching logic.

    Returns:
        The set of request ids to temporarily deprioritize in this round.
    """
    temporary_deprioritized: Set[int] = set()
    # Fresh radix tree per round; it only tracks prefixes seen in this batch.
    self.waiting_queue_radix_tree.reset()

    for r in waiting_queue:
        prefix_ids = r.adjust_max_prefix_ids()

        # NOTE: the prefix_indices must always be aligned with last_node
        r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
            rid=r.rid, key=prefix_ids
        )

        # NOTE(sang): This logic is for in-batch prefix caching;
        # If there are more than 1 request that have small matching prefix from
        # existing cache, but all those requests share the same prefix, we prefer
        # to schedule only one of them so that we can increase the cache hit rate.
        # We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small
        # threshold means we cannot use in-batch prefix caching for short prefixes.
        # It is kind of common when the engine is long running (e.g., imagine the prefix "the").
        if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
            in_batch_matching_prefixes, _ = (
                self.waiting_queue_radix_tree.match_prefix(
                    rid=r.rid, key=prefix_ids
                )
            )
            if (
                len(in_batch_matching_prefixes)
                >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
            ):
                temporary_deprioritized.add(r.rid)
            else:
                # Insert with a dummy key
                self.waiting_queue_radix_tree.insert(
                    prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
                )
    return temporary_deprioritized
@staticmethod
def _sort_by_longest_prefix(
    waiting_queue: List[Req], temporary_deprioritized: Set[int]
) -> None:
    """Sorts the waiting queue based on the longest prefix match.

    Requests flagged in ``temporary_deprioritized`` are pushed to the back
    of the queue (in-batch prefix caching).
    """

    def rank(req) -> float:
        # Deprioritized requests sort after everything else.
        if req.rid in temporary_deprioritized:
            return float("inf")
        # Negate so longer cached prefixes come first in ascending order.
        return -len(req.prefix_indices)

    waiting_queue.sort(key=rank)
@staticmethod
def _sort_by_dfs_weight(
    waiting_queue: List[Req], tree_cache: BasePrefixCache
) -> None:
    """Sorts the waiting queue based on a depth-first search weighting."""
    # Group requests by the tree node where their cached prefix ends.
    last_node_to_reqs = defaultdict(list)
    for req in waiting_queue:
        last_node_to_reqs[req.last_node].append(req)

    # Seed each node's weight with its own request count, then accumulate
    # subtree weights bottom-up from the root.
    node_to_weight = defaultdict(int)
    for node in last_node_to_reqs:
        node_to_weight[node] = len(last_node_to_reqs[node])
    SchedulePolicy._calc_weight(tree_cache.root_node, node_to_weight)

    # Rebuild the queue in DFS order, visiting heavier subtrees first.
    waiting_queue.clear()
    SchedulePolicy._get_dfs_priority(
        tree_cache.root_node,
        node_to_weight,
        last_node_to_reqs,
        waiting_queue,
    )
@staticmethod
def _sort_by_longest_output(waiting_queue: List[Req]) -> None:
    """Sorts the waiting queue based on the longest output (max_new_tokens)."""
    # Descending by requested output length; Python's sort is stable, so
    # ties keep their arrival order.
    waiting_queue.sort(
        key=lambda req: req.sampling_params.max_new_tokens, reverse=True
    )
@staticmethod
def _sort_randomly(waiting_queue: List[Req]) -> None:
    """Shuffles the waiting queue randomly."""
    # In-place shuffle using the module-level `random` state.
    random.shuffle(waiting_queue)
@staticmethod
def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
    """Accumulate subtree weights bottom-up.

    After the call, ``node_to_weight[cur_node]`` includes the weights of all
    of its descendants (post-order recursion).
    """
    for child in cur_node.children.values():
        SchedulePolicy._calc_weight(child, node_to_weight)
        node_to_weight[cur_node] += node_to_weight[child]
@staticmethod
def _get_dfs_priority(
    cur_node: TreeNode,
    node_to_priority: Dict[TreeNode, int],
    last_node_to_reqs: Dict[TreeNode, List[Req]],
    q: List,
) -> None:
    """Depth-first traversal that appends requests to ``q``.

    Children are visited in descending priority (subtree weight) order;
    each node's own requests are appended after its subtree is done.
    """
    childs = [child for child in cur_node.children.values()]
    # Heavier subtrees first (negate for ascending sort).
    childs.sort(key=lambda x: -node_to_priority[x])
    for child in childs:
        SchedulePolicy._get_dfs_priority(
            child, node_to_priority, last_node_to_reqs, q
        )
    q.extend(last_node_to_reqs[cur_node])
...
...
test/srt/test_schedule_policy.py
0 → 100644
View file @
bdb3929d
import
unittest
from
sglang.srt.managers.schedule_batch
import
Req
from
sglang.srt.managers.schedule_policy
import
(
CacheAgnosticPolicy
,
CacheAwarePolicy
,
SchedulePolicy
,
)
from
sglang.srt.mem_cache.radix_cache
import
RadixCache
,
TreeNode
from
sglang.srt.sampling.sampling_params
import
SamplingParams
class TestSchedulePolicy(unittest.TestCase):
    """Unit tests for SchedulePolicy construction and FCFS ordering."""

    def setUp(self):
        # A real (enabled) radix cache; positional args per RadixCache's
        # signature — presumably (req_to_token_pool, token_to_kv_pool,
        # disable). TODO(review): confirm against RadixCache.__init__.
        self.tree_cache = RadixCache(None, None, False)

    def test_init_with_cache_aware_policy(self):
        # "lpm" should resolve to the cache-aware enum when the cache is on.
        policy = SchedulePolicy(policy="lpm", tree_cache=self.tree_cache)
        self.assertEqual(policy.policy, CacheAwarePolicy.LPM)

    def test_init_with_cache_agnostic_policy(self):
        policy = SchedulePolicy(policy="fcfs", tree_cache=self.tree_cache)
        self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)

    def test_init_with_unknown_policy(self):
        # Strings outside both enums must be rejected at construction time.
        with self.assertRaises(ValueError):
            SchedulePolicy(policy="invalid", tree_cache=self.tree_cache)

    def test_init_with_disabled_cache(self):
        # With the tree cache disabled, cache-aware "lpm" degrades to FCFS.
        disabled_tree_cache = RadixCache(None, None, disable=True)
        policy = SchedulePolicy(policy="lpm", tree_cache=disabled_tree_cache)
        self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)

    def test_calc_priority_fcfs(self):
        tree_cache = RadixCache(None, None, False)
        waiting_queue = [
            Req(1, "a b", [1, 2], SamplingParams()),
            Req(3, "a b c", [1, 2, 3], SamplingParams()),
            Req(2, "a", [1], SamplingParams()),
        ]
        policy = SchedulePolicy(policy="fcfs", tree_cache=tree_cache)
        policy.calc_priority(waiting_queue)
        # Check if FCFS keeps the original order
        self.assertEqual(waiting_queue[0].rid, 1)
        self.assertEqual(waiting_queue[1].rid, 3)
        self.assertEqual(waiting_queue[2].rid, 2)


if __name__ == "__main__":
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment