Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
bdb3929d
"vscode:/vscode.git/clone" did not exist on "d340ea3aa861a936430b9f37caafa2788bfae185"
Unverified
Commit
bdb3929d
authored
Jan 04, 2025
by
libra
Committed by
GitHub
Jan 04, 2025
Browse files
Refactor SchedulePolicy to improve code organization (#2571)
parent
f5d0865b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
211 additions
and
90 deletions
+211
-90
python/sglang/srt/managers/schedule_policy.py
python/sglang/srt/managers/schedule_policy.py
+159
-90
test/srt/test_schedule_policy.py
test/srt/test_schedule_policy.py
+52
-0
No files found.
python/sglang/srt/managers/schedule_policy.py
View file @
bdb3929d
...
@@ -18,7 +18,7 @@ import random
...
@@ -18,7 +18,7 @@ import random
from
collections
import
defaultdict
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Union
import
torch
import
torch
...
@@ -50,13 +50,26 @@ IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
...
@@ -50,13 +50,26 @@ IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
)
)
class CacheAwarePolicy(Enum):
    """Scheduling policies that are aware of the tree cache."""

    LPM = "lpm"  # longest prefix match
    DFS_WEIGHT = "dfs-weight"  # depth-first search weighting


class CacheAgnosticPolicy(Enum):
    """Scheduling policies that are not aware of the tree cache."""

    FCFS = "fcfs"  # first come first serve
    LOF = "lof"  # longest output first
    RANDOM = "random"
class SchedulePolicy:
    # Union of the two policy enums; the active policy is always one of these
    # after __init__ has validated/adjusted the raw string.
    Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]

    def __init__(self, policy: str, tree_cache: BasePrefixCache):
        """Create a schedule policy.

        Args:
            policy: Raw policy name (e.g. "lpm", "fcfs", "lof", "random",
                "dfs-weight"); converted to a CacheAwarePolicy or
                CacheAgnosticPolicy enum member. Raises ValueError for an
                unknown name.
            tree_cache: The prefix cache; when it is disabled, cache-aware
                policies fall back to FCFS.
        """
        self.policy = self._validate_and_adjust_policy(policy, tree_cache)
        self.tree_cache = tree_cache

        # It is used to find the matching prefix for in-batch prefix caching.
        self.waiting_queue_radix_tree = RadixCache(
            req_to_token_pool=None, token_to_kv_pool=None, disable=False
        )
def
calc_priority
(
self
,
waiting_queue
:
List
[
Req
]):
def
calc_priority
(
self
,
waiting_queue
:
List
[
Req
])
->
bool
:
if
len
(
waiting_queue
)
>
128
and
self
.
policy
==
"lpm"
:
policy
=
self
.
_determine_active_policy
(
waiting_queue
)
# Turn off the expensive prefix matching and sorting when the #queue is large.
policy
=
"fcfs"
else
:
policy
=
self
.
policy
# Compute matched prefix length
prefix_computed
=
False
prefix_computed
=
False
if
policy
==
"lpm"
or
policy
==
"dfs-weight"
:
if
isinstance
(
policy
,
CacheAwarePolicy
):
# rid to deprioritize in the current run for in-batch prefix caching.
prefix_computed
=
True
temporary_deprioritized
=
set
()
temporary_deprioritized
=
self
.
_compute_prefix_matches
(
self
.
waiting_queue_radix_tree
.
reset
()
waiting_queue
,
policy
)
for
r
in
waiting_queue
:
if
policy
==
CacheAwarePolicy
.
LPM
:
prefix_ids
=
r
.
adjust_max_prefix_ids
()
SchedulePolicy
.
_sort_by_longest_prefix
(
waiting_queue
,
temporary_deprioritized
# NOTE: the prefix_indices must always be aligned with last_node
r
.
prefix_indices
,
r
.
last_node
=
self
.
tree_cache
.
match_prefix
(
rid
=
r
.
rid
,
key
=
prefix_ids
)
)
elif
policy
==
CacheAwarePolicy
.
DFS_WEIGHT
:
SchedulePolicy
.
_sort_by_dfs_weight
(
waiting_queue
,
self
.
tree_cache
)
else
:
raise
ValueError
(
f
"Unknown CacheAware Policy:
{
policy
=
}
"
)
else
:
if
policy
==
CacheAgnosticPolicy
.
FCFS
:
pass
elif
policy
==
CacheAgnosticPolicy
.
LOF
:
SchedulePolicy
.
_sort_by_longest_output
(
waiting_queue
)
elif
policy
==
CacheAgnosticPolicy
.
RANDOM
:
SchedulePolicy
.
_sort_randomly
(
waiting_queue
)
else
:
raise
ValueError
(
f
"Unknown CacheAgnostic Policy:
{
policy
=
}
"
)
# NOTE(sang): This logic is for in-batch prefix caching;
return
prefix_computed
# If there are more than 1 request that have small matching prefix from
# existing cache, but all those requests share the same prefix, we prefer
# to schedule only one of them so that we can increase the cache hit rate.
# We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small
# threshold means we cannot use in-batch prefix caching for short prefixes.
# It is kind of common when the engine is long running (e.g., imagine the prefix "the").
if
len
(
r
.
prefix_indices
)
<=
IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD
:
in_batch_matching_prefixes
,
_
=
(
self
.
waiting_queue_radix_tree
.
match_prefix
(
rid
=
r
.
rid
,
key
=
prefix_ids
)
)
if
(
len
(
in_batch_matching_prefixes
)
>=
IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
):
temporary_deprioritized
.
add
(
r
.
rid
)
else
:
# Insert with a dummy key
self
.
waiting_queue_radix_tree
.
insert
(
prefix_ids
,
torch
.
empty
(
len
(
prefix_ids
),
dtype
=
torch
.
bool
)
)
prefix_computed
=
True
def
_determine_active_policy
(
self
,
waiting_queue
:
List
[
Req
])
->
Policy
:
if
len
(
waiting_queue
)
>
128
and
self
.
policy
==
CacheAwarePolicy
.
LPM
:
# Turn off the expensive prefix matching and sorting when the #queue is large.
return
CacheAgnosticPolicy
.
FCFS
return
self
.
policy
def
_validate_and_adjust_policy
(
self
,
policy
:
str
,
tree_cache
:
BasePrefixCache
)
->
Policy
:
"""
Validates the policy and adjusts it if necessary based on tree cache settings.
"""
try
:
policy_enum
=
CacheAwarePolicy
(
policy
)
if
tree_cache
.
disable
:
# If tree_cache is disabled, using CacheAgnosticPolicy policy
return
CacheAgnosticPolicy
.
FCFS
return
policy_enum
except
ValueError
:
try
:
return
CacheAgnosticPolicy
(
policy
)
except
ValueError
:
raise
ValueError
(
f
"Unknown schedule_policy:
{
policy
=
}
"
)
def
_compute_prefix_matches
(
self
,
waiting_queue
:
List
[
Req
],
policy
:
CacheAwarePolicy
)
->
Set
[
int
]:
"""
Computes and caches the matching prefixes for requests in the waiting queue,
and handles in-batch prefix caching logic.
"""
temporary_deprioritized
:
Set
[
int
]
=
set
()
self
.
waiting_queue_radix_tree
.
reset
()
for
r
in
waiting_queue
:
prefix_ids
=
r
.
adjust_max_prefix_ids
()
# NOTE: the prefix_indices must always be aligned with last_node
r
.
prefix_indices
,
r
.
last_node
=
self
.
tree_cache
.
match_prefix
(
rid
=
r
.
rid
,
key
=
prefix_ids
)
if
policy
==
"lpm"
:
# NOTE(sang): This logic is for in-batch prefix caching;
# Longest Prefix Match
# If there are more than 1 request that have small matching prefix from
waiting_queue
.
sort
(
# existing cache, but all those requests share the same prefix, we prefer
key
=
lambda
r
:
(
# to schedule only one of them so that we can increase the cache hit rate.
-
len
(
r
.
prefix_indices
)
# We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small
if
r
.
rid
not
in
temporary_deprioritized
# threshold means we cannot use in-batch prefix caching for short prefixes.
else
float
(
"inf"
)
# It is kind of common when the engine is long running (e.g., imagine the prefix "the").
if
len
(
r
.
prefix_indices
)
<=
IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD
:
in_batch_matching_prefixes
,
_
=
(
self
.
waiting_queue_radix_tree
.
match_prefix
(
rid
=
r
.
rid
,
key
=
prefix_ids
)
)
)
if
(
len
(
in_batch_matching_prefixes
)
>=
IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
):
temporary_deprioritized
.
add
(
r
.
rid
)
else
:
# Insert with a dummy key
self
.
waiting_queue_radix_tree
.
insert
(
prefix_ids
,
torch
.
empty
(
len
(
prefix_ids
),
dtype
=
torch
.
bool
)
)
return
temporary_deprioritized
@
staticmethod
def
_sort_by_longest_prefix
(
waiting_queue
:
List
[
Req
],
temporary_deprioritized
:
Set
[
int
]
)
->
None
:
"""Sorts the waiting queue based on the longest prefix match."""
waiting_queue
.
sort
(
key
=
lambda
r
:
(
-
len
(
r
.
prefix_indices
)
if
r
.
rid
not
in
temporary_deprioritized
else
float
(
"inf"
)
)
)
elif
policy
==
"fcfs"
:
)
# first come first serve
pass
elif
policy
==
"lof"
:
# longest output first
waiting_queue
.
sort
(
key
=
lambda
x
:
-
x
.
sampling_params
.
max_new_tokens
)
elif
policy
==
"random"
:
random
.
shuffle
(
waiting_queue
)
elif
policy
==
"dfs-weight"
:
# Experimental policy based on custom weights
last_node_to_reqs
=
defaultdict
(
list
)
for
req
in
waiting_queue
:
last_node_to_reqs
[
req
.
last_node
].
append
(
req
)
node_to_weight
=
defaultdict
(
int
)
for
node
in
last_node_to_reqs
:
node_to_weight
[
node
]
=
len
(
last_node_to_reqs
[
node
])
self
.
calc_weight
(
self
.
tree_cache
.
root_node
,
node_to_weight
)
waiting_queue
.
clear
()
self
.
get_dfs_priority
(
self
.
tree_cache
.
root_node
,
node_to_weight
,
last_node_to_reqs
,
waiting_queue
,
)
else
:
raise
ValueError
(
f
"Unknown schedule_policy:
{
policy
=
}
"
)
return
prefix_computed
@
staticmethod
def
_sort_by_dfs_weight
(
waiting_queue
:
List
[
Req
],
tree_cache
:
BasePrefixCache
)
->
None
:
"""Sorts the waiting queue based on a depth-first search weighting."""
last_node_to_reqs
=
defaultdict
(
list
)
for
req
in
waiting_queue
:
last_node_to_reqs
[
req
.
last_node
].
append
(
req
)
node_to_weight
=
defaultdict
(
int
)
for
node
in
last_node_to_reqs
:
node_to_weight
[
node
]
=
len
(
last_node_to_reqs
[
node
])
SchedulePolicy
.
_calc_weight
(
tree_cache
.
root_node
,
node_to_weight
)
waiting_queue
.
clear
()
SchedulePolicy
.
_get_dfs_priority
(
tree_cache
.
root_node
,
node_to_weight
,
last_node_to_reqs
,
waiting_queue
,
)
@
staticmethod
def
_sort_by_longest_output
(
waiting_queue
:
List
[
Req
])
->
None
:
"""Sorts the waiting queue based on the longest output (max_new_tokens)."""
waiting_queue
.
sort
(
key
=
lambda
x
:
-
x
.
sampling_params
.
max_new_tokens
)
def
calc_weight
(
self
,
cur_node
:
TreeNode
,
node_to_weight
:
Dict
):
@
staticmethod
def
_sort_randomly
(
waiting_queue
:
List
[
Req
])
->
None
:
"""Shuffles the waiting queue randomly."""
random
.
shuffle
(
waiting_queue
)
@
staticmethod
def
_calc_weight
(
cur_node
:
TreeNode
,
node_to_weight
:
Dict
[
TreeNode
,
int
])
->
None
:
for
child
in
cur_node
.
children
.
values
():
for
child
in
cur_node
.
children
.
values
():
self
.
calc_weight
(
child
,
node_to_weight
)
SchedulePolicy
.
_
calc_weight
(
child
,
node_to_weight
)
node_to_weight
[
cur_node
]
+=
node_to_weight
[
child
]
node_to_weight
[
cur_node
]
+=
node_to_weight
[
child
]
def
get_dfs_priority
(
@
staticmethod
self
,
def
_get_dfs_priority
(
cur_node
:
TreeNode
,
cur_node
:
TreeNode
,
node_to_priority
:
Dict
[
TreeNode
,
int
],
node_to_priority
:
Dict
[
TreeNode
,
int
],
last_node_to_reqs
:
Dict
[
TreeNode
,
List
[
Req
]],
last_node_to_reqs
:
Dict
[
TreeNode
,
List
[
Req
]],
q
:
List
,
q
:
List
,
):
)
->
None
:
childs
=
[
child
for
child
in
cur_node
.
children
.
values
()]
childs
=
[
child
for
child
in
cur_node
.
children
.
values
()]
childs
.
sort
(
key
=
lambda
x
:
-
node_to_priority
[
x
])
childs
.
sort
(
key
=
lambda
x
:
-
node_to_priority
[
x
])
for
child
in
childs
:
for
child
in
childs
:
self
.
get_dfs_priority
(
child
,
node_to_priority
,
last_node_to_reqs
,
q
)
SchedulePolicy
.
_get_dfs_priority
(
child
,
node_to_priority
,
last_node_to_reqs
,
q
)
q
.
extend
(
last_node_to_reqs
[
cur_node
])
q
.
extend
(
last_node_to_reqs
[
cur_node
])
...
...
test/srt/test_schedule_policy.py
0 → 100644
View file @
bdb3929d
import unittest

from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.schedule_policy import (
    CacheAgnosticPolicy,
    CacheAwarePolicy,
    SchedulePolicy,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.sampling.sampling_params import SamplingParams


class TestSchedulePolicy(unittest.TestCase):
    """Unit tests for SchedulePolicy policy resolution and queue ordering."""

    def setUp(self):
        self.tree_cache = RadixCache(None, None, False)

    def test_init_with_cache_aware_policy(self):
        # "lpm" with an enabled cache resolves to the cache-aware enum.
        policy = SchedulePolicy(policy="lpm", tree_cache=self.tree_cache)
        self.assertEqual(policy.policy, CacheAwarePolicy.LPM)

    def test_init_with_cache_agnostic_policy(self):
        policy = SchedulePolicy(policy="fcfs", tree_cache=self.tree_cache)
        self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)

    def test_init_with_unknown_policy(self):
        # Unknown policy names must raise rather than silently fall back.
        with self.assertRaises(ValueError):
            SchedulePolicy(policy="invalid", tree_cache=self.tree_cache)

    def test_init_with_disabled_cache(self):
        # Cache-aware policies degrade to FCFS when the tree cache is disabled.
        disabled_tree_cache = RadixCache(None, None, disable=True)
        policy = SchedulePolicy(policy="lpm", tree_cache=disabled_tree_cache)
        self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)

    def test_calc_priority_fcfs(self):
        tree_cache = RadixCache(None, None, False)
        waiting_queue = [
            Req(1, "a b", [1, 2], SamplingParams()),
            Req(3, "a b c", [1, 2, 3], SamplingParams()),
            Req(2, "a", [1], SamplingParams()),
        ]
        policy = SchedulePolicy(policy="fcfs", tree_cache=tree_cache)
        policy.calc_priority(waiting_queue)
        # Check if FCFS keeps the original order
        self.assertEqual(waiting_queue[0].rid, 1)
        self.assertEqual(waiting_queue[1].rid, 3)
        self.assertEqual(waiting_queue[2].rid, 2)


if __name__ == "__main__":
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment