test_matched_stop.py 2.88 KB
Newer Older
1
2
import unittest

3
from sglang.srt.sampling.sampling_params import MAX_LEN, get_max_seq_length
4
from sglang.srt.utils import kill_process_tree
5
from sglang.test.kits.matched_stop_kit import MatchedStopMixin
6
7
8
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_TEST,
9
    CustomTestCase,
10
11
12
13
    popen_launch_server,
)


14
class TestMatchedStop(CustomTestCase, MatchedStopMixin):
    """Matched-stop tests run against a server launched once for the class.

    This class only manages the server lifecycle; the test methods are
    presumably provided by ``MatchedStopMixin`` (confirm against the kit).
    """

    @classmethod
    def setUpClass(cls):
        """Launch the shared test server and record its model, URL and process."""
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=300,
            other_args=["--max-running-requests", "10"],
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the process tree rooted at the server started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)
29
30


31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
class TestRegexPatternMaxLength(unittest.TestCase):
    """Check ``get_max_seq_length`` against hand-computed bounds for sample regexes."""

    @classmethod
    def setUpClass(cls):
        """Build the table mapping each regex string to its expected max length.

        An expected value of ``MAX_LEN`` marks a pattern whose match length is
        unbounded; every other value is the exact maximum number of characters
        the pattern can match, derived in the comment above the entry.
        """
        cls.regex_str_to_max_len = {
            # - '*' → infinite tokens need to be stored
            "((ab|cd(e|f){2}){3,5}g|hij)*k": MAX_LEN,
            # - '*?' → infinite tokens still need to be stored even if lazy
            #   matching is used
            "abc*?k": MAX_LEN,
            # - '^' and '$' don't add any characters to the max length
            # "spec" → 4
            # "(foo|at)" → max(3, 2) = 3
            # Whole regex = 7
            "^spec(foo|at)$": 7,
            # - Innermost alt: "fg" vs "hi" → 2
            # - Repeat {2,3}: max = 3 * 2 = 6
            # - Inner group "de(...)": 2 (for "de") + 6 = 8
            # - "bca" or "de(...)" → max(3, 8) = 8
            # - Whole group: "a"(1) + group(8) + "j"(1) = 10
            # - Repeat {2} → 20
            # - Add "kl"(2) → 22
            "(a(bca|de(fg|hi){2,3})j){2}kl": 22,
            # Branch 1:
            #   "foo"(3) + max("bar"(3), "baz"(3) + "qux"{2} = 3 + 6 = 9) = 3 + 9 = 12
            # Branch 2:
            #   "x"(1) + "yz"{10} = 1 + 20 = 21
            # Whole regex = max(12, 21) = 21
            "(foo(bar|baz(qux){1,2}))|(x(yz){5,10})": 21,
            # Branch A:
            #   (a|bc){1,3} → max = 3 * 2 = 6
            #   Inside: d(e|f){2} = 1 + 2 * 1 = 3 vs gh = 2 → max = 3
            #   Repeat {2,4} → 4 * 3 = 12
            #   Branch A total = 18
            # Branch B:
            #   "ijk"(3) vs "lmp(no|p){3}" = 3 + 3 * max(2, 1) = 3 + 6 = 9 → max = 9
            #   Branch B total = 9
            # Whole outer alt = max(18, 9) = 18
            # Repeat {5} → 90
            "(((a|bc){1,3}(d(e|f){2}|gh){2,4})|(ijk|lmp(no|p){3})){5}": 90,
        }

    def test_get_max_length(self):
        """Compare the computed max length of every table entry with its expectation.

        Each regex runs in its own ``subTest`` so a failure names the offending
        pattern and the remaining patterns are still checked.
        """
        for regex_str, max_len in self.regex_str_to_max_len.items():
            with self.subTest(regex=regex_str):
                actual = get_max_seq_length(regex_str)
                if max_len == MAX_LEN:
                    # Unbounded patterns only need to reach the MAX_LEN cap.
                    self.assertGreaterEqual(actual, MAX_LEN)
                else:
                    self.assertEqual(actual, max_len)


79
80
# Run the tests in this module with unittest's runner when executed as a script.
if __name__ == "__main__":
    unittest.main()