test_speculative_registry.py 5.47 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import unittest

from sglang.srt.speculative import spec_info as spec_info_module
from sglang.srt.speculative.spec_info import (
    SpeculativeAlgorithm,
    register_speculative_algorithm,
)


class DummyWorker:
    def __init__(self, **kwargs):
        self.kwargs = kwargs


class SpeculativeRegistryTests(unittest.TestCase):
    def test_nextn_alias_maps_to_eagle(self):
        eagle = SpeculativeAlgorithm.from_string("EAGLE")
        alias = SpeculativeAlgorithm.from_string("NEXTN")
        self.assertIs(alias, eagle)

    def test_register_speculative_algorithm_registers_worker_and_flags(self):
        original_next_value = SpeculativeAlgorithm._next_value
        algo = register_speculative_algorithm(
            "TEST_SPEC_ALGO",
            DummyWorker,
            aliases=("TEST_SPEC_ALIAS",),
            flags=("EAGLE",),
            override_worker=True,
        )
        self.addCleanup(self._cleanup_registered_algorithm, algo, ("TEST_SPEC_ALIAS",))
        self.addCleanup(
            setattr, SpeculativeAlgorithm, "_next_value", original_next_value
        )

        self.assertIs(SpeculativeAlgorithm.from_string("TEST_SPEC_ALGO"), algo)
        self.assertIs(SpeculativeAlgorithm.from_string("TEST_SPEC_ALIAS"), algo)
        self.assertTrue(algo.is_eagle())
        self.assertIs(SpeculativeAlgorithm.from_value(int(algo)), algo)
        self.assertIn(algo, list(spec_info_module._REGISTERED_WORKERS))

        worker = algo.create_draft_worker(example_arg=42)
        self.assertIsInstance(worker, DummyWorker)
        self.assertEqual(worker.kwargs["example_arg"], 42)

    def test_builtin_algorithms_flags_and_factories(self):
        cases = {
            "NONE": {
                "is_none": True,
                "is_eagle": False,
                "is_eagle3": False,
                "is_standalone": False,
                "is_ngram": False,
                "has_factory": False,
            },
            "EAGLE": {
                "is_none": False,
                "is_eagle": True,
                "is_eagle3": False,
                "is_standalone": False,
                "is_ngram": False,
                "has_factory": True,
            },
            "EAGLE3": {
                "is_none": False,
                "is_eagle": True,
                "is_eagle3": True,
                "is_standalone": False,
                "is_ngram": False,
                "has_factory": True,
            },
            "STANDALONE": {
                "is_none": False,
                "is_eagle": False,
                "is_eagle3": False,
                "is_standalone": True,
                "is_ngram": False,
                "has_factory": True,
            },
            "NGRAM": {
                "is_none": False,
                "is_eagle": False,
                "is_eagle3": False,
                "is_standalone": False,
                "is_ngram": True,
                "has_factory": True,
            },
        }

        for name, expectations in cases.items():
            with self.subTest(name=name):
                algo = SpeculativeAlgorithm.from_string(name)
                self.assertEqual(algo.name, name)
                self.assertEqual(algo.is_none(), expectations["is_none"])
                self.assertEqual(algo.is_eagle(), expectations["is_eagle"])
                self.assertEqual(algo.is_eagle3(), expectations["is_eagle3"])
                self.assertEqual(algo.is_standalone(), expectations["is_standalone"])
                self.assertEqual(algo.is_ngram(), expectations["is_ngram"])

                has_factory = algo._draft_worker_factory is not None
                self.assertEqual(has_factory, expectations["has_factory"])
                self.assertIs(SpeculativeAlgorithm.from_value(int(algo)), algo)

        self.assertIs(SpeculativeAlgorithm.from_string(None), SpeculativeAlgorithm.NONE)

    def test_iteration_returns_registration_order(self):
        names = [algo.name for algo in SpeculativeAlgorithm._registration_order]
        for required in ["NONE", "EAGLE", "EAGLE3", "STANDALONE", "NGRAM"]:
            self.assertIn(required, names)

    def test_create_draft_worker_returns_none_for_none_algorithm(self):
        self.assertIsNone(SpeculativeAlgorithm.NONE.create_draft_worker())

    def test_register_draft_worker_override(self):
        algo = SpeculativeAlgorithm.from_string("EAGLE")
        original_factory = algo._draft_worker_factory

        def dummy_factory(_: SpeculativeAlgorithm, **kwargs):
            return "dummy"

        SpeculativeAlgorithm.register_draft_worker(algo, dummy_factory)
        self.addCleanup(
            SpeculativeAlgorithm.register_draft_worker, algo, original_factory
        )

        self.assertEqual(algo.create_draft_worker(), "dummy")

    def _cleanup_registered_algorithm(self, algorithm: SpeculativeAlgorithm, aliases):
        name = algorithm.name
        SpeculativeAlgorithm._registry_by_value.pop(algorithm.value, None)
        SpeculativeAlgorithm._registry_by_name.pop(name, None)
        if hasattr(SpeculativeAlgorithm, name):
            delattr(SpeculativeAlgorithm, name)

        for alias in aliases:
            SpeculativeAlgorithm._registry_by_name.pop(alias, None)

        try:
            SpeculativeAlgorithm._registration_order.remove(algorithm)
        except ValueError:
            pass

        for flag_values in SpeculativeAlgorithm._flags.values():
            flag_values.discard(algorithm.value)

        spec_info_module._REGISTERED_WORKERS.pop(algorithm, None)


if __name__ == "__main__":
    unittest.main()