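"""Tests for sliding window attention with the Triton attention backend.

Launches an SGLang server with google/gemma-3-4b-it (a model that uses
sliding window attention), then checks short-context generation,
long-context generation, and MMLU accuracy, both with CUDA graphs enabled
(the default) and disabled.

Usage:
    python3 test_triton_sliding_window.py
"""
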
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
)


class TestSlidingWindowAttentionTriton(CustomTestCase):
    """Test sliding window attention functionality with triton backend."""

    @classmethod
    def setUpClass(cls):
        """Set up the test server with Gemma3 model and triton backend."""
        # Gemma3 model supports sliding window attention
        cls.model = "google/gemma-3-4b-it"
        cls.base_url = DEFAULT_URL_FOR_TEST

        cls.common_args = [
            "--trust-remote-code",
            "--attention-backend",
            "triton",
            "--context-length",
            "8192",
            "--random-seed",
            "42",
        ]

        # Short prompt with a single, unambiguous expected completion.
        cls.short_context_prompt = "The capital of France is"

        # Prompt repeated enough times that the total input length exceeds
        # the model's sliding window size.
        cls.long_context_prompt = (
            """
        Once upon a time, there was a mountain. In the mountain, there was a temple. In the temple, there was an old monk telling a story. The story was:
        """
            * 100
        )
        cls.long_context_prompt += "\nNow, summarize the story in one sentence:"

    def _test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=200,
            num_threads=32,
        )

        metrics = run_eval(args)
        print(f"MMLU metrics with sliding window: {metrics}")

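        # A score well below this threshold typically points to a correctness
        # regression in the sliding-window attention path rather than noise.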
        self.assertGreaterEqual(metrics["score"], 0.60)

    def _test_short_context_generation(self):
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": self.short_context_prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 256,
                },
            },
        )

        self.assertEqual(response.status_code, 200)
        result = response.json()
        self.assertIn("paris", result["text"].lower())
        print(f"Short context generation result: {result['text']}")

    def _test_long_context_generation(self):
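        # Unlike the short-context test, there is no single expected answer
        # here; we only check that the server handles an input longer than
        # the sliding window without erroring or returning empty output.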
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": self.long_context_prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 256,
                },
            },
        )

        self.assertEqual(response.status_code, 200)
        result = response.json()
        self.assertGreater(len(result["text"].strip()), 0)
        print(f"Long context generation result: {result['text'][:100]}...")

    @unittest.skipIf(is_in_ci(), "Skipped in CI to reduce execution time.")
    def test_no_cuda_graph(self):
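        # Same checks as test_cuda_graph, but with CUDA graph capture
        # disabled, to isolate the attention kernels from graph replay.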
        self.no_cuda_graph_process = popen_launch_server(
            self.model,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=self.common_args + ["--disable-cuda-graph"],
        )

        try:
            self._test_short_context_generation()
            self._test_long_context_generation()
            self._test_mmlu()
        finally:
            kill_process_tree(self.no_cuda_graph_process.pid)

    def test_cuda_graph(self):
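        # Default server flags leave CUDA graph capture enabled, so this
        # covers the sliding-window kernels under graph replay as well.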
        self.cuda_graph_process = popen_launch_server(
            self.model,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=self.common_args,
        )

        try:
            self._test_short_context_generation()
            self._test_long_context_generation()
            self._test_mmlu()
        finally:
            kill_process_tree(self.cuda_graph_process.pid)


if __name__ == "__main__":
    unittest.main()