sglang, commit bdb3929d (unverified)

Refactor SchedulePolicy to improve code organization (#2571)

Authored Jan 04, 2025 by libra; committed by GitHub on Jan 04, 2025.
Parent: f5d0865b
Showing 2 changed files with 211 additions and 90 deletions (+211 -90):
  python/sglang/srt/managers/schedule_policy.py  (+159 -90)
  test/srt/test_schedule_policy.py  (+52 -0)
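Taken together, the two changed files replace ad-hoc string comparisons with two policy enums and split calc_priority into small static helpers. As a minimal usage sketch of the refactored API, following the construction patterns in the new test file below (this is an illustration, not a snippet from the repository):

    from sglang.srt.managers.schedule_batch import Req
    from sglang.srt.managers.schedule_policy import (
        CacheAgnosticPolicy,
        CacheAwarePolicy,
        SchedulePolicy,
    )
    from sglang.srt.mem_cache.radix_cache import RadixCache
    from sglang.srt.sampling.sampling_params import SamplingParams

    tree_cache = RadixCache(None, None, False)

    # Policy strings are now resolved to enum members at construction time.
    assert SchedulePolicy("lpm", tree_cache).policy is CacheAwarePolicy.LPM
    assert SchedulePolicy("fcfs", tree_cache).policy is CacheAgnosticPolicy.FCFS

    # calc_priority() reorders the waiting queue in place and returns whether
    # prefix matching was computed; FCFS leaves the order untouched.
    waiting_queue = [
        Req(1, "a b", [1, 2], SamplingParams()),
        Req(2, "a", [1], SamplingParams()),
    ]
    prefix_computed = SchedulePolicy("fcfs", tree_cache).calc_priority(waiting_queue)
    assert prefix_computed is False
    assert [r.rid for r in waiting_queue] == [1, 2]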
python/sglang/srt/managers/schedule_policy.py (view file @ bdb3929d)
@@ -18,7 +18,7 @@ import random
 from collections import defaultdict
 from contextlib import contextmanager
 from enum import Enum, auto
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Set, Union
 
 import torch
@@ -50,13 +50,26 @@ IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
 )
 
 
+class CacheAwarePolicy(Enum):
+    """Scheduling policies that are aware of the tree cache."""
+
+    LPM = "lpm"  # longest prefix match
+    DFS_WEIGHT = "dfs-weight"  # depth-first search weighting
+
+
+class CacheAgnosticPolicy(Enum):
+    """Scheduling policies that are not aware of the tree cache."""
+
+    FCFS = "fcfs"  # first come first serve
+    LOF = "lof"  # longest output first
+    RANDOM = "random"
+
+
 class SchedulePolicy:
-    def __init__(self, policy: str, tree_cache: BasePrefixCache):
-        if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
-            # LPM and DFS-weight is meaningless when the tree cache is disabled.
-            policy = "fcfs"
-        self.policy = policy
+    Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]
+
+    def __init__(self, policy: str, tree_cache: BasePrefixCache):
+        self.policy = self._validate_and_adjust_policy(policy, tree_cache)
         self.tree_cache = tree_cache
 
         # It is used to find the matching prefix for in-batch prefix caching.
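The enum values keep the existing policy strings, so lookup-by-value still works; _validate_and_adjust_policy below relies on exactly this plain Enum behavior. A small illustration (standard Python enum semantics, not code from the repository):

    from sglang.srt.managers.schedule_policy import CacheAgnosticPolicy, CacheAwarePolicy

    # Lookup by value succeeds for known policy strings...
    assert CacheAwarePolicy("lpm") is CacheAwarePolicy.LPM
    assert CacheAgnosticPolicy("lof") is CacheAgnosticPolicy.LOF

    # ...and raises ValueError for unknown ones, which the validator catches to
    # fall through from the cache-aware enum to the cache-agnostic one.
    try:
        CacheAwarePolicy("fcfs")
    except ValueError:
        pass  # "fcfs" is only a CacheAgnosticPolicy value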
@@ -64,18 +77,67 @@ class SchedulePolicy:
             req_to_token_pool=None,
             token_to_kv_pool=None,
             disable=False,
         )
 
-    def calc_priority(self, waiting_queue: List[Req]):
-        if len(waiting_queue) > 128 and self.policy == "lpm":
-            # Turn off the expensive prefix matching and sorting when the #queue is large.
-            policy = "fcfs"
-        else:
-            policy = self.policy
+    def calc_priority(self, waiting_queue: List[Req]) -> bool:
+        policy = self._determine_active_policy(waiting_queue)
 
         # Compute matched prefix length
         prefix_computed = False
-        if policy == "lpm" or policy == "dfs-weight":
-            # rid to deprioritize in the current run for in-batch prefix caching.
-            temporary_deprioritized = set()
-            self.waiting_queue_radix_tree.reset()
-
-            for r in waiting_queue:
+        if isinstance(policy, CacheAwarePolicy):
+            prefix_computed = True
+            temporary_deprioritized = self._compute_prefix_matches(waiting_queue, policy)
+            if policy == CacheAwarePolicy.LPM:
+                SchedulePolicy._sort_by_longest_prefix(
+                    waiting_queue, temporary_deprioritized
+                )
+            elif policy == CacheAwarePolicy.DFS_WEIGHT:
+                SchedulePolicy._sort_by_dfs_weight(waiting_queue, self.tree_cache)
+            else:
+                raise ValueError(f"Unknown CacheAware Policy: {policy=}")
+        else:
+            if policy == CacheAgnosticPolicy.FCFS:
+                pass
+            elif policy == CacheAgnosticPolicy.LOF:
+                SchedulePolicy._sort_by_longest_output(waiting_queue)
+            elif policy == CacheAgnosticPolicy.RANDOM:
+                SchedulePolicy._sort_randomly(waiting_queue)
+            else:
+                raise ValueError(f"Unknown CacheAgnostic Policy: {policy=}")
+        return prefix_computed
+
+    def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
+        if len(waiting_queue) > 128 and self.policy == CacheAwarePolicy.LPM:
+            # Turn off the expensive prefix matching and sorting when the #queue is large.
+            return CacheAgnosticPolicy.FCFS
+        return self.policy
+
+    def _validate_and_adjust_policy(
+        self, policy: str, tree_cache: BasePrefixCache
+    ) -> Policy:
+        """
+        Validates the policy and adjusts it if necessary based on tree cache settings.
+        """
+        try:
+            policy_enum = CacheAwarePolicy(policy)
+            if tree_cache.disable:
+                # If tree_cache is disabled, using CacheAgnosticPolicy policy
+                return CacheAgnosticPolicy.FCFS
+            return policy_enum
+        except ValueError:
+            try:
+                return CacheAgnosticPolicy(policy)
+            except ValueError:
+                raise ValueError(f"Unknown schedule_policy: {policy=}")
+
+    def _compute_prefix_matches(
+        self, waiting_queue: List[Req], policy: CacheAwarePolicy
+    ) -> Set[int]:
+        """
+        Computes and caches the matching prefixes for requests in the waiting queue,
+        and handles in-batch prefix caching logic.
+        """
+        temporary_deprioritized: Set[int] = set()
+        self.waiting_queue_radix_tree.reset()
+
+        for r in waiting_queue:
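The validation and fallback behavior introduced above can be exercised directly. The disabled-cache fallback and the 128-request LPM cutoff are both taken from the diff; calling the private _determine_active_policy with a list of placeholders is purely illustrative:

    from sglang.srt.managers.schedule_policy import (
        CacheAgnosticPolicy,
        CacheAwarePolicy,
        SchedulePolicy,
    )
    from sglang.srt.mem_cache.radix_cache import RadixCache

    # A cache-aware policy silently degrades to FCFS when the tree cache is disabled.
    disabled_cache = RadixCache(None, None, disable=True)
    assert SchedulePolicy("lpm", disabled_cache).policy is CacheAgnosticPolicy.FCFS

    # At runtime, LPM also falls back to FCFS once the waiting queue exceeds 128
    # requests; the configured policy itself is left unchanged.
    lpm = SchedulePolicy("lpm", RadixCache(None, None, False))
    assert lpm._determine_active_policy([None] * 129) is CacheAgnosticPolicy.FCFS
    assert lpm.policy is CacheAwarePolicy.LPM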
@@ -109,11 +171,13 @@ class SchedulePolicy:
             self.waiting_queue_radix_tree.insert(
                 prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
             )
+        return temporary_deprioritized
 
-            prefix_computed = True
-
-        if policy == "lpm":
-            # Longest Prefix Match
-            waiting_queue.sort(
-                key=lambda r: (
-                    -len(r.prefix_indices)
+    @staticmethod
+    def _sort_by_longest_prefix(
+        waiting_queue: List[Req], temporary_deprioritized: Set[int]
+    ) -> None:
+        """Sorts the waiting queue based on the longest prefix match."""
+        waiting_queue.sort(
+            key=lambda r: (
+                -len(r.prefix_indices)
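The sort key in _sort_by_longest_prefix orders requests by matched prefix length, with requests flagged for in-batch deprioritization pushed to the back of the queue. A self-contained toy version of that ordering, where FakeReq is a hypothetical stand-in for Req carrying only the fields the key touches:

    from dataclasses import dataclass, field
    from typing import List, Set

    @dataclass
    class FakeReq:
        rid: int
        prefix_indices: List[int] = field(default_factory=list)

    def lpm_sort(queue: List[FakeReq], deprioritized: Set[int]) -> None:
        # Longer cached prefix first; deprioritized request ids go to the back.
        queue.sort(
            key=lambda r: (
                -len(r.prefix_indices) if r.rid not in deprioritized else float("inf")
            )
        )

    queue = [FakeReq(1, [7]), FakeReq(2, [7, 8, 9]), FakeReq(3, [7, 8])]
    lpm_sort(queue, deprioritized={2})
    assert [r.rid for r in queue] == [3, 1, 2]  # rid 2 loses its head start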
@@ -121,16 +185,12 @@ class SchedulePolicy:
                 else float("inf")
             )
         )
-        elif policy == "fcfs":
-            # first come first serve
-            pass
-        elif policy == "lof":
-            # longest output first
-            waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-        elif policy == "random":
-            random.shuffle(waiting_queue)
-        elif policy == "dfs-weight":
-            # Experimental policy based on custom weights
-            last_node_to_reqs = defaultdict(list)
-            for req in waiting_queue:
-                last_node_to_reqs[req.last_node].append(req)
+
+    @staticmethod
+    def _sort_by_dfs_weight(
+        waiting_queue: List[Req], tree_cache: BasePrefixCache
+    ) -> None:
+        """Sorts the waiting queue based on a depth-first search weighting."""
+        last_node_to_reqs = defaultdict(list)
+        for req in waiting_queue:
+            last_node_to_reqs[req.last_node].append(req)
@@ -138,36 +198,45 @@ class SchedulePolicy:
         node_to_weight = defaultdict(int)
         for node in last_node_to_reqs:
             node_to_weight[node] = len(last_node_to_reqs[node])
-            self.calc_weight(self.tree_cache.root_node, node_to_weight)
+        SchedulePolicy._calc_weight(tree_cache.root_node, node_to_weight)
 
-            waiting_queue.clear()
-            self.get_dfs_priority(
-                self.tree_cache.root_node,
-                node_to_weight,
-                last_node_to_reqs,
-                waiting_queue,
-            )
-        else:
-            raise ValueError(f"Unknown schedule_policy: {policy=}")
-
-        return prefix_computed
+        waiting_queue.clear()
+        SchedulePolicy._get_dfs_priority(
+            tree_cache.root_node,
+            node_to_weight,
+            last_node_to_reqs,
+            waiting_queue,
+        )
 
-    def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
+    @staticmethod
+    def _sort_by_longest_output(waiting_queue: List[Req]) -> None:
+        """Sorts the waiting queue based on the longest output (max_new_tokens)."""
+        waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
+
+    @staticmethod
+    def _sort_randomly(waiting_queue: List[Req]) -> None:
+        """Shuffles the waiting queue randomly."""
+        random.shuffle(waiting_queue)
+
+    @staticmethod
+    def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
         for child in cur_node.children.values():
-            self.calc_weight(child, node_to_weight)
+            SchedulePolicy._calc_weight(child, node_to_weight)
             node_to_weight[cur_node] += node_to_weight[child]
 
-    def get_dfs_priority(
-        self,
+    @staticmethod
+    def _get_dfs_priority(
         cur_node: TreeNode,
         node_to_priority: Dict[TreeNode, int],
         last_node_to_reqs: Dict[TreeNode, List[Req]],
         q: List,
-    ):
+    ) -> None:
         childs = [child for child in cur_node.children.values()]
         childs.sort(key=lambda x: -node_to_priority[x])
         for child in childs:
-            self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
+            SchedulePolicy._get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
         q.extend(last_node_to_reqs[cur_node])
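For the DFS-weight path, the two static helpers implement a two-pass scheme: _calc_weight accumulates per-subtree request counts bottom-up, and _get_dfs_priority walks the tree heaviest-subtree-first while appending the requests parked at each node. A toy run of the same idea on a hand-built tree, where Node is a hypothetical stand-in for TreeNode with only a children dict:

    from collections import defaultdict

    class Node:
        """Hypothetical stand-in for TreeNode: only the children mapping is needed."""
        def __init__(self, name=""):
            self.name = name
            self.children = {}

    def calc_weight(cur, node_to_weight):
        # Post-order accumulation: a node's weight counts the requests in its subtree.
        for child in cur.children.values():
            calc_weight(child, node_to_weight)
            node_to_weight[cur] += node_to_weight[child]

    def get_dfs_priority(cur, node_to_priority, last_node_to_reqs, q):
        # Visit heavier subtrees first, then emit the requests parked at this node.
        childs = sorted(cur.children.values(), key=lambda x: -node_to_priority[x])
        for child in childs:
            get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
        q.extend(last_node_to_reqs[cur])

    root, a, b = Node("root"), Node("a"), Node("b")
    root.children = {"a": a, "b": b}

    # Pretend one request ends at node "a" and three at node "b".
    last_node_to_reqs = defaultdict(list, {a: ["req1"], b: ["req2", "req3", "req4"]})
    node_to_weight = defaultdict(int, {n: len(r) for n, r in last_node_to_reqs.items()})

    calc_weight(root, node_to_weight)
    order = []
    get_dfs_priority(root, node_to_weight, last_node_to_reqs, order)
    print(order)  # ['req2', 'req3', 'req4', 'req1']: the heavier subtree is scheduled first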
test/srt/test_schedule_policy.py (new file, 0 → 100644, view file @ bdb3929d)
+import unittest
+
+from sglang.srt.managers.schedule_batch import Req
+from sglang.srt.managers.schedule_policy import (
+    CacheAgnosticPolicy,
+    CacheAwarePolicy,
+    SchedulePolicy,
+)
+from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
+from sglang.srt.sampling.sampling_params import SamplingParams
+
+
+class TestSchedulePolicy(unittest.TestCase):
+
+    def setUp(self):
+        self.tree_cache = RadixCache(None, None, False)
+
+    def test_init_with_cache_aware_policy(self):
+        policy = SchedulePolicy(policy="lpm", tree_cache=self.tree_cache)
+        self.assertEqual(policy.policy, CacheAwarePolicy.LPM)
+
+    def test_init_with_cache_agnostic_policy(self):
+        policy = SchedulePolicy(policy="fcfs", tree_cache=self.tree_cache)
+        self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)
+
+    def test_init_with_unknown_policy(self):
+        with self.assertRaises(ValueError):
+            SchedulePolicy(policy="invalid", tree_cache=self.tree_cache)
+
+    def test_init_with_disabled_cache(self):
+        disabled_tree_cache = RadixCache(None, None, disable=True)
+        policy = SchedulePolicy(policy="lpm", tree_cache=disabled_tree_cache)
+        self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)
+
+    def test_calc_priority_fcfs(self):
+        tree_cache = RadixCache(None, None, False)
+        waiting_queue = [
+            Req(1, "a b", [1, 2], SamplingParams()),
+            Req(3, "a b c", [1, 2, 3], SamplingParams()),
+            Req(2, "a", [1], SamplingParams()),
+        ]
+        policy = SchedulePolicy(policy="fcfs", tree_cache=tree_cache)
+        policy.calc_priority(waiting_queue)
+
+        # Check if FCFS keeps the original order
+        self.assertEqual(waiting_queue[0].rid, 1)
+        self.assertEqual(waiting_queue[1].rid, 3)
+        self.assertEqual(waiting_queue[2].rid, 2)
+
+
+if __name__ == "__main__":
+    unittest.main()
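Since the new test module ends with the standard unittest.main() guard, it can be executed directly (for example `python test/srt/test_schedule_policy.py` from the repository root, assuming sglang is importable) or picked up by the usual `python -m unittest` discovery.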