import unittest

from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.schedule_policy import (
    CacheAgnosticPolicy,
    CacheAwarePolicy,
    SchedulePolicy,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.test.test_utils import CustomTestCase


class TestSchedulePolicy(CustomTestCase):
    """Unit tests for SchedulePolicy construction and waiting-queue ordering.

    Every test builds its policy through ``_make_policy`` so that the
    keyword arguments shared by all cases live in exactly one place, and
    checks the resulting queue order through ``_assert_rid_order``.
    """

    def setUp(self):
        # unittest calls setUp before each test method, so every test gets
        # its own enabled radix cache.
        self.tree_cache = RadixCache(None, None, False)

    def _make_policy(
        self,
        policy,
        *,
        tree_cache=None,
        enable_priority_scheduling=False,
        schedule_low_priority_values_first=False,
    ):
        """Build a SchedulePolicy with the settings common to these tests.

        Args:
            policy: Policy name passed through to SchedulePolicy (e.g. "fcfs").
            tree_cache: Cache to schedule against; defaults to the per-test
                cache created in ``setUp``.
            enable_priority_scheduling: Forwarded to SchedulePolicy.
            schedule_low_priority_values_first: Forwarded to SchedulePolicy.

        Returns:
            The constructed SchedulePolicy (hierarchical cache always enabled,
            as in every case of this suite).
        """
        return SchedulePolicy(
            policy=policy,
            tree_cache=self.tree_cache if tree_cache is None else tree_cache,
            enable_hierarchical_cache=True,
            enable_priority_scheduling=enable_priority_scheduling,
            schedule_low_priority_values_first=schedule_low_priority_values_first,
        )

    def _assert_rid_order(self, waiting_queue, expected_rids):
        """Assert that the queue holds exactly ``expected_rids``, in order."""
        self.assertEqual([req.rid for req in waiting_queue], expected_rids)

    def test_init_with_cache_aware_policy(self):
        """"lpm" with an enabled cache resolves to CacheAwarePolicy.LPM."""
        policy = self._make_policy("lpm")
        self.assertEqual(policy.policy, CacheAwarePolicy.LPM)

    def test_init_with_cache_agnostic_policy(self):
        """"fcfs" resolves to CacheAgnosticPolicy.FCFS."""
        policy = self._make_policy("fcfs")
        self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)

    def test_init_with_unknown_policy(self):
        """An unrecognised policy name is rejected with ValueError."""
        with self.assertRaises(ValueError):
            self._make_policy("invalid")

    def test_init_with_disabled_cache(self):
        """With the radix cache disabled, "lpm" is expected to become FCFS."""
        disabled_tree_cache = RadixCache(None, None, disable=True, page_size=1)
        policy = self._make_policy("lpm", tree_cache=disabled_tree_cache)
        self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)

    def test_calc_priority_fcfs(self):
        """Plain FCFS leaves the waiting queue untouched."""
        waiting_queue = [
            Req(1, "a b", [1, 2], SamplingParams()),
            Req(3, "a b c", [1, 2, 3], SamplingParams()),
            Req(2, "a", [1], SamplingParams()),
        ]

        policy = self._make_policy("fcfs")
        policy.calc_priority(waiting_queue)

        # Check if FCFS keeps the original order
        self._assert_rid_order(waiting_queue, [1, 3, 2])

    def test_calc_priority_priority_enabled_fcfs_scheduling(self):
        """Priority FCFS: larger priority value first, then earlier queue time."""
        waiting_queue = [
            Req(1, "a b", [1, 2], SamplingParams()),
            Req(3, "a b c", [1, 2, 3], SamplingParams()),
            Req(2, "a", [1], SamplingParams()),
        ]
        waiting_queue[0].priority, waiting_queue[0].queue_time_start = 1, 1
        waiting_queue[1].priority, waiting_queue[1].queue_time_start = 0, 1
        waiting_queue[2].priority, waiting_queue[2].queue_time_start = 0, 0

        policy = self._make_policy("fcfs", enable_priority_scheduling=True)
        policy.calc_priority(waiting_queue)

        # rid 1 has the highest priority value; rid 2 and rid 3 tie on
        # priority, so the earlier queue_time_start (rid 2) goes first.
        self._assert_rid_order(waiting_queue, [1, 2, 3])

    def test_calc_priority_priority_enabled_fcfs_scheduling_with_low_priority_values_first(
        self,
    ):
        """Priority FCFS, low values first: smaller priority value wins."""
        waiting_queue = [
            Req(1, "a b", [1, 2], SamplingParams()),
            Req(3, "a b c", [1, 2, 3], SamplingParams()),
            Req(2, "a", [1], SamplingParams()),
        ]
        waiting_queue[0].priority, waiting_queue[0].queue_time_start = -1, 0
        waiting_queue[1].priority, waiting_queue[1].queue_time_start = 0, 1
        waiting_queue[2].priority, waiting_queue[2].queue_time_start = 0, 0

        policy = self._make_policy(
            "fcfs",
            enable_priority_scheduling=True,
            schedule_low_priority_values_first=True,
        )
        policy.calc_priority(waiting_queue)

        # rid 1 has the lowest priority value; rid 2 and rid 3 tie on
        # priority, so the earlier queue_time_start (rid 2) goes first.
        self._assert_rid_order(waiting_queue, [1, 2, 3])

    def test_calc_priority_longest_output_first_scheduling(self):
        """"lof" without priorities orders by descending max_new_tokens."""
        waiting_queue = [
            Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1000)),
            Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10)),
            Req(2, "a", [1], SamplingParams(max_new_tokens=100)),
        ]

        policy = self._make_policy("lof")
        policy.calc_priority(waiting_queue)

        # Check that requests are ordered by max_new_tokens, largest first
        # (1000 -> 100 -> 10).
        self._assert_rid_order(waiting_queue, [1, 2, 3])

    def test_calc_priority_priority_enabled_longest_output_first_scheduling(self):
        """Priority "lof": larger priority value first, then longer output."""
        waiting_queue = [
            Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1), priority=1),
            Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10), priority=0),
            Req(2, "a", [1], SamplingParams(max_new_tokens=100), priority=0),
        ]

        policy = self._make_policy("lof", enable_priority_scheduling=True)
        policy.calc_priority(waiting_queue)

        # rid 1 leads on priority despite the smallest max_new_tokens; the
        # tied rid 2 / rid 3 are ordered by max_new_tokens (100 before 10).
        self._assert_rid_order(waiting_queue, [1, 2, 3])

    def test_calc_priority_priority_enabled_longest_output_first_scheduling_with_low_priority_values_first(
        self,
    ):
        """Priority "lof", low values first: smaller priority, then longer output."""
        waiting_queue = [
            Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1), priority=0),
            Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10), priority=1),
            Req(2, "a", [1], SamplingParams(max_new_tokens=100), priority=1),
        ]

        policy = self._make_policy(
            "lof",
            enable_priority_scheduling=True,
            schedule_low_priority_values_first=True,
        )
        policy.calc_priority(waiting_queue)

        # rid 1 leads with the lowest priority value; the tied rid 2 / rid 3
        # are ordered by max_new_tokens (100 before 10).
        self._assert_rid_order(waiting_queue, [1, 2, 3])


if __name__ == "__main__":
    unittest.main()