Unverified Commit bdb3929d authored by libra, committed by GitHub

Refactor SchedulePolicy to improve code organization (#2571)

parent f5d0865b
@@ -18,7 +18,7 @@
 import random
 from collections import defaultdict
 from contextlib import contextmanager
 from enum import Enum, auto
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Set, Union
 
 import torch
@@ -50,13 +50,26 @@ IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
 )
 
+
+class CacheAwarePolicy(Enum):
+    """Scheduling policies that are aware of the tree cache."""
+
+    LPM = "lpm"  # longest prefix match
+    DFS_WEIGHT = "dfs-weight"  # depth-first search weighting
+
+
+class CacheAgnosticPolicy(Enum):
+    """Scheduling policies that are not aware of the tree cache."""
+
+    FCFS = "fcfs"  # first come first serve
+    LOF = "lof"  # longest output first
+    RANDOM = "random"
+
+
 class SchedulePolicy:
-    def __init__(self, policy: str, tree_cache: BasePrefixCache):
-        if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
-            # LPM and DFS-weight is meaningless when the tree cache is disabled.
-            policy = "fcfs"
-        self.policy = policy
+    Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]
+
+    def __init__(self, policy: str, tree_cache: BasePrefixCache):
+        self.policy = self._validate_and_adjust_policy(policy, tree_cache)
         self.tree_cache = tree_cache
 
         # It is used to find the matching prefix for in-batch prefix caching.
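The heart of the refactor is the string-to-enum validation that replaces the old inline string checks; `_validate_and_adjust_policy` in the next hunk implements it. Here is a minimal, stand-alone sketch of the pattern (the `cache_disabled` flag is an assumed stand-in for the real `tree_cache.disable` attribute):

from enum import Enum
from typing import Union


class CacheAwarePolicy(Enum):
    LPM = "lpm"
    DFS_WEIGHT = "dfs-weight"


class CacheAgnosticPolicy(Enum):
    FCFS = "fcfs"
    LOF = "lof"
    RANDOM = "random"


Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]


def validate_and_adjust_policy(policy: str, cache_disabled: bool) -> Policy:
    # Enum(value) raises ValueError on unknown values, which drives the fallback chain.
    try:
        policy_enum = CacheAwarePolicy(policy)
        # Cache-aware policies are meaningless without the tree cache; degrade to FCFS.
        return CacheAgnosticPolicy.FCFS if cache_disabled else policy_enum
    except ValueError:
        try:
            return CacheAgnosticPolicy(policy)
        except ValueError:
            raise ValueError(f"Unknown schedule_policy: {policy=}")


assert validate_and_adjust_policy("lpm", cache_disabled=False) is CacheAwarePolicy.LPM
assert validate_and_adjust_policy("lpm", cache_disabled=True) is CacheAgnosticPolicy.FCFS
assert validate_and_adjust_policy("random", cache_disabled=True) is CacheAgnosticPolicy.RANDOM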
@@ -64,18 +77,67 @@ class SchedulePolicy:
             req_to_token_pool=None, token_to_kv_pool=None, disable=False
         )
 
-    def calc_priority(self, waiting_queue: List[Req]):
-        if len(waiting_queue) > 128 and self.policy == "lpm":
-            # Turn off the expensive prefix matching and sorting when the #queue is large.
-            policy = "fcfs"
-        else:
-            policy = self.policy
+    def calc_priority(self, waiting_queue: List[Req]) -> bool:
+        policy = self._determine_active_policy(waiting_queue)
 
-        # Compute matched prefix length
         prefix_computed = False
-        if policy == "lpm" or policy == "dfs-weight":
-            # rid to deprioritize in the current run for in-batch prefix caching.
-            temporary_deprioritized = set()
+        if isinstance(policy, CacheAwarePolicy):
+            prefix_computed = True
+            temporary_deprioritized = self._compute_prefix_matches(
+                waiting_queue, policy
+            )
+            if policy == CacheAwarePolicy.LPM:
+                SchedulePolicy._sort_by_longest_prefix(
+                    waiting_queue, temporary_deprioritized
+                )
+            elif policy == CacheAwarePolicy.DFS_WEIGHT:
+                SchedulePolicy._sort_by_dfs_weight(waiting_queue, self.tree_cache)
+            else:
+                raise ValueError(f"Unknown CacheAware Policy: {policy=}")
+        else:
+            if policy == CacheAgnosticPolicy.FCFS:
+                pass
+            elif policy == CacheAgnosticPolicy.LOF:
+                SchedulePolicy._sort_by_longest_output(waiting_queue)
+            elif policy == CacheAgnosticPolicy.RANDOM:
+                SchedulePolicy._sort_randomly(waiting_queue)
+            else:
+                raise ValueError(f"Unknown CacheAgnostic Policy: {policy=}")
+        return prefix_computed
+
+    def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
+        if len(waiting_queue) > 128 and self.policy == CacheAwarePolicy.LPM:
+            # Turn off the expensive prefix matching and sorting when the queue is large.
+            return CacheAgnosticPolicy.FCFS
+        return self.policy
+
+    def _validate_and_adjust_policy(
+        self, policy: str, tree_cache: BasePrefixCache
+    ) -> Policy:
+        """
+        Validates the policy and adjusts it if necessary based on tree cache settings.
+        """
+        try:
+            policy_enum = CacheAwarePolicy(policy)
+            if tree_cache.disable:
+                # Cache-aware policies are meaningless when the tree cache is disabled.
+                return CacheAgnosticPolicy.FCFS
+            return policy_enum
+        except ValueError:
+            try:
+                return CacheAgnosticPolicy(policy)
+            except ValueError:
+                raise ValueError(f"Unknown schedule_policy: {policy=}")
+
+    def _compute_prefix_matches(
+        self, waiting_queue: List[Req], policy: CacheAwarePolicy
+    ) -> Set[int]:
+        """
+        Computes and caches the matching prefixes for requests in the waiting queue,
+        and handles in-batch prefix caching logic.
+        """
+        temporary_deprioritized: Set[int] = set()
         self.waiting_queue_radix_tree.reset()
 
         for r in waiting_queue:
@@ -109,11 +171,13 @@ class SchedulePolicy:
                 self.waiting_queue_radix_tree.insert(
                     prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
                 )
+        return temporary_deprioritized
 
-        prefix_computed = True
-
-        if policy == "lpm":
-            # Longest Prefix Match
+    @staticmethod
+    def _sort_by_longest_prefix(
+        waiting_queue: List[Req], temporary_deprioritized: Set[int]
+    ) -> None:
+        """Sorts the waiting queue based on the longest prefix match."""
         waiting_queue.sort(
             key=lambda r: (
                 -len(r.prefix_indices)
@@ -121,16 +185,12 @@ class SchedulePolicy:
                 else float("inf")
             )
         )
-        elif policy == "fcfs":
-            # first come first serve
-            pass
-        elif policy == "lof":
-            # longest output first
-            waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-        elif policy == "random":
-            random.shuffle(waiting_queue)
-        elif policy == "dfs-weight":
-            # Experimental policy based on custom weights
+
+    @staticmethod
+    def _sort_by_dfs_weight(
+        waiting_queue: List[Req], tree_cache: BasePrefixCache
+    ) -> None:
+        """Sorts the waiting queue based on a depth-first search weighting."""
         last_node_to_reqs = defaultdict(list)
         for req in waiting_queue:
             last_node_to_reqs[req.last_node].append(req)
@@ -138,36 +198,45 @@ class SchedulePolicy:
         node_to_weight = defaultdict(int)
         for node in last_node_to_reqs:
             node_to_weight[node] = len(last_node_to_reqs[node])
-        self.calc_weight(self.tree_cache.root_node, node_to_weight)
+        SchedulePolicy._calc_weight(tree_cache.root_node, node_to_weight)
 
         waiting_queue.clear()
-        self.get_dfs_priority(
-            self.tree_cache.root_node,
+        SchedulePolicy._get_dfs_priority(
+            tree_cache.root_node,
             node_to_weight,
             last_node_to_reqs,
             waiting_queue,
         )
-        else:
-            raise ValueError(f"Unknown schedule_policy: {policy=}")
 
-        return prefix_computed
+    @staticmethod
+    def _sort_by_longest_output(waiting_queue: List[Req]) -> None:
+        """Sorts the waiting queue based on the longest output (max_new_tokens)."""
+        waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
 
-    def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
+    @staticmethod
+    def _sort_randomly(waiting_queue: List[Req]) -> None:
+        """Shuffles the waiting queue randomly."""
+        random.shuffle(waiting_queue)
+
+    @staticmethod
+    def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
         for child in cur_node.children.values():
-            self.calc_weight(child, node_to_weight)
+            SchedulePolicy._calc_weight(child, node_to_weight)
             node_to_weight[cur_node] += node_to_weight[child]
 
-    def get_dfs_priority(
-        self,
+    @staticmethod
+    def _get_dfs_priority(
         cur_node: TreeNode,
         node_to_priority: Dict[TreeNode, int],
         last_node_to_reqs: Dict[TreeNode, List[Req]],
         q: List,
-    ):
+    ) -> None:
         childs = [child for child in cur_node.children.values()]
         childs.sort(key=lambda x: -node_to_priority[x])
         for child in childs:
-            self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
+            SchedulePolicy._get_dfs_priority(
+                child, node_to_priority, last_node_to_reqs, q
+            )
         q.extend(last_node_to_reqs[cur_node])
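For intuition about the dfs-weight ordering: `_calc_weight` accumulates, bottom-up, how many waiting requests end in each subtree, and `_get_dfs_priority` then walks heavier subtrees first so requests sharing a cached prefix come out adjacent in the queue. The following minimal sketch reproduces the two helpers on a toy tree; the `TreeNode` stub and string-valued requests are hypothetical stand-ins for sglang's radix-tree nodes and `Req` objects.

# Toy reproduction of _calc_weight / _get_dfs_priority (stand-alone sketch).
from collections import defaultdict
from typing import Dict, List


class TreeNode:  # hypothetical stand-in for sglang's radix-tree node
    def __init__(self, name: str):
        self.name = name
        self.children: Dict[str, "TreeNode"] = {}


def calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
    # Post-order accumulation: a node's weight is the number of requests in its subtree.
    for child in cur_node.children.values():
        calc_weight(child, node_to_weight)
        node_to_weight[cur_node] += node_to_weight[child]


def get_dfs_priority(
    cur_node: TreeNode,
    node_to_priority: Dict[TreeNode, int],
    last_node_to_reqs: Dict[TreeNode, List[str]],
    q: List[str],
) -> None:
    # Visit heavier subtrees first so requests with a popular shared prefix stay adjacent.
    for child in sorted(cur_node.children.values(), key=lambda x: -node_to_priority[x]):
        get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
    q.extend(last_node_to_reqs[cur_node])


# Toy tree: root -> a (two requests), root -> b (one request).
root, a, b = TreeNode("root"), TreeNode("a"), TreeNode("b")
root.children = {"a": a, "b": b}
last_node_to_reqs = defaultdict(list, {a: ["req1", "req2"], b: ["req3"]})

node_to_weight = defaultdict(int)
for node, reqs in last_node_to_reqs.items():
    node_to_weight[node] = len(reqs)
calc_weight(root, node_to_weight)

q: List[str] = []
get_dfs_priority(root, node_to_weight, last_node_to_reqs, q)
print(q)  # ['req1', 'req2', 'req3']: subtree 'a' (weight 2) is visited before 'b' (weight 1)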
import unittest

from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.schedule_policy import (
    CacheAgnosticPolicy,
    CacheAwarePolicy,
    SchedulePolicy,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.sampling.sampling_params import SamplingParams


class TestSchedulePolicy(unittest.TestCase):
    def setUp(self):
        self.tree_cache = RadixCache(None, None, False)

    def test_init_with_cache_aware_policy(self):
        policy = SchedulePolicy(policy="lpm", tree_cache=self.tree_cache)
        self.assertEqual(policy.policy, CacheAwarePolicy.LPM)

    def test_init_with_cache_agnostic_policy(self):
        policy = SchedulePolicy(policy="fcfs", tree_cache=self.tree_cache)
        self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)

    def test_init_with_unknown_policy(self):
        with self.assertRaises(ValueError):
            SchedulePolicy(policy="invalid", tree_cache=self.tree_cache)

    def test_init_with_disabled_cache(self):
        disabled_tree_cache = RadixCache(None, None, disable=True)
        policy = SchedulePolicy(policy="lpm", tree_cache=disabled_tree_cache)
        self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)

    def test_calc_priority_fcfs(self):
        tree_cache = RadixCache(None, None, False)
        waiting_queue = [
            Req(1, "a b", [1, 2], SamplingParams()),
            Req(3, "a b c", [1, 2, 3], SamplingParams()),
            Req(2, "a", [1], SamplingParams()),
        ]
        policy = SchedulePolicy(policy="fcfs", tree_cache=tree_cache)
        policy.calc_priority(waiting_queue)
        # Check if FCFS keeps the original order
        self.assertEqual(waiting_queue[0].rid, 1)
        self.assertEqual(waiting_queue[1].rid, 3)
        self.assertEqual(waiting_queue[2].rid, 2)


if __name__ == "__main__":
    unittest.main()
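Assuming the new test module is saved as test_schedule_policy.py somewhere on the import path (the file name is elided in this view), the suite also runs under the standard library runner:

python -m unittest -v test_schedule_policy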