"""Unit tests for SchedulePolicy (test_schedule_policy.py)."""

import unittest

from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.schedule_policy import (
    CacheAgnosticPolicy,
    CacheAwarePolicy,
    SchedulePolicy,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.test.test_utils import CustomTestCase


class TestSchedulePolicy(CustomTestCase):
    """Tests for how ``SchedulePolicy`` resolves its policy and orders requests."""

    def setUp(self):
        """Create a fresh ``RadixCache`` for each test."""
        self.tree_cache = RadixCache(None, None, False)

    def test_init_with_cache_aware_policy(self):
        """The ``"lpm"`` policy string resolves to ``CacheAwarePolicy.LPM``."""
        policy = SchedulePolicy(
            policy="lpm", tree_cache=self.tree_cache, enable_hierarchical_cache=True
        )
        self.assertEqual(policy.policy, CacheAwarePolicy.LPM)

    def test_init_with_cache_agnostic_policy(self):
        """The ``"fcfs"`` policy string resolves to ``CacheAgnosticPolicy.FCFS``."""
        policy = SchedulePolicy(
            policy="fcfs", tree_cache=self.tree_cache, enable_hierarchical_cache=True
        )
        self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)

    def test_init_with_unknown_policy(self):
        """An unrecognized policy string raises ``ValueError`` at construction."""
        with self.assertRaises(ValueError):
            SchedulePolicy(
                policy="invalid",
                tree_cache=self.tree_cache,
                enable_hierarchical_cache=True,
            )

    def test_init_with_disabled_cache(self):
        """With a disabled tree cache, ``"lpm"`` is expected to resolve to FCFS."""
        disabled_tree_cache = RadixCache(None, None, disable=True, page_size=1)
        policy = SchedulePolicy(
            policy="lpm", tree_cache=disabled_tree_cache, enable_hierarchical_cache=True
        )
        self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)

    def test_calc_priority_fcfs(self):
        """``calc_priority`` under ``"fcfs"`` leaves the waiting queue order unchanged."""
        # Request ids are deliberately not sorted, so any reordering by id
        # (or by prompt length) would be detected.
        waiting_queue = [
            Req(1, "a b", [1, 2], SamplingParams()),
            Req(3, "a b c", [1, 2, 3], SamplingParams()),
            Req(2, "a", [1], SamplingParams()),
        ]

        # Reuse the per-test cache from setUp; it is built with the same arguments.
        policy = SchedulePolicy(
            policy="fcfs", tree_cache=self.tree_cache, enable_hierarchical_cache=True
        )
        policy.calc_priority(waiting_queue)

        # Check if FCFS keeps the original order
        self.assertEqual([req.rid for req in waiting_queue], [1, 3, 2])


# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()