# test_request_queue_validation.py
import asyncio
import os
import re
import unittest

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    STDERR_FILENAME,
    STDOUT_FILENAME,
    CustomTestCase,
    popen_launch_server,
    send_concurrent_generate_requests,
    send_generate_requests,
)


class TestMaxQueuedRequests(CustomTestCase):
    """End-to-end checks for ``--max-running-requests`` / ``--max-queued-requests``.

    One server is launched for the whole class, limited to a single running
    request and a single queued request. Its stdout/stderr are redirected to
    ``STDOUT_FILENAME`` / ``STDERR_FILENAME`` so the server log can be scanned
    by ``test_max_running_requests_and_max_queued_request_validation``.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST

        # The server's output is captured here; STDERR_FILENAME is parsed by
        # the log-based test below and both files are deleted in tearDownClass.
        cls.stdout = open(STDOUT_FILENAME, "w")
        cls.stderr = open(STDERR_FILENAME, "w")

        try:
            cls.process = popen_launch_server(
                cls.model,
                cls.base_url,
                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                other_args=(
                    "--max-running-requests",  # Allow at most 1 running request
                    "1",
                    "--max-queued-requests",  # Allow at most 1 queued request
                    "1",
                    "--attention-backend",
                    "triton",
                ),
                return_stdout_stderr=(cls.stdout, cls.stderr),
            )
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # close the log handles here instead of leaking them. The files are
            # left on disk so the failed launch can be inspected.
            cls.stdout.close()
            cls.stderr.close()
            raise

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)
        cls.stdout.close()
        cls.stderr.close()
        os.remove(STDOUT_FILENAME)
        os.remove(STDERR_FILENAME)

    def test_max_queued_requests_validation_with_serial_requests(self):
        """Verify serial requests are not throttled when the max concurrency is 1."""
        status_codes = send_generate_requests(
            self.base_url,
            num_requests=10,
        )

        for index, status_code in enumerate(status_codes):
            # Serial requests should never overflow the queue, so none may be
            # rejected.
            self.assertEqual(status_code, 200, f"request #{index} was throttled")

    def test_max_queued_requests_validation_with_concurrent_requests(self):
        """Verify request throttling with concurrent requests."""
        status_codes = asyncio.run(
            send_concurrent_generate_requests(self.base_url, num_requests=10)
        )

        # With one running slot and one queue slot (see setUpClass), two
        # requests are accepted and the remaining eight are rejected with 503.
        expected_status_codes = [200, 200, 503, 503, 503, 503, 503, 503, 503, 503]
        self.assertEqual(status_codes, expected_status_codes)

    def test_max_running_requests_and_max_queued_request_validation(self):
        """Verify running and queued request counts in the server log never exceed 1."""
        rr_pattern = re.compile(r"#running-req:\s*(\d+)")
        qr_pattern = re.compile(r"#queue-req:\s*(\d+)")

        # NOTE(review): this only inspects lines the server has logged so far;
        # if no matching stats lines were written (e.g. when this test runs in
        # isolation), it passes vacuously.
        with open(STDERR_FILENAME, encoding="utf-8", errors="replace") as log_file:
            for line in log_file:
                rr_match = rr_pattern.search(line)
                if rr_match:
                    self.assertLessEqual(int(rr_match.group(1)), 1, line)
                qr_match = qr_pattern.search(line)
                if qr_match:
                    self.assertLessEqual(int(qr_match.group(1)), 1, line)


if __name__ == "__main__":
    # Run this module's test cases when executed directly as a script.
    unittest.main(module="__main__")