# test_fa3.py: integration tests for the FlashAttention3 ("fa3") attention backend.
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.environ import envs
from sglang.srt.utils import get_device_sm, kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
    DEFAULT_MODEL_NAME_FOR_TEST_MLA,
    DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)

# Optional path to a local GSM8K test set. ``None`` leaves the data source to
# run_eval_few_shot_gsm8k (presumably its built-in default — confirm).
GSM_DATASET_PATH = None

# Set OFFLINE_MODE to True on machines without internet access; the model
# names and the dataset path are then replaced by the local paths below.
OFFLINE_MODE = False

# Local substitutes used when OFFLINE_MODE is True; change them to match your
# machine. The GSM8K entry is keyed by GSM_DATASET_PATH itself (i.e. ``None``).
OFFLINE_PATH_DICT = {
    DEFAULT_MODEL_NAME_FOR_TEST: "/shared/public/elr-models/meta-llama/Meta-Llama-3.1-8B-Instruct",
    DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3: "/shared/public/elr-models/jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
    DEFAULT_MODEL_NAME_FOR_TEST_MLA: "/shared/public/sharing/deepseek/dsv3-test/snapshots/",
    DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN: "/shared/public/sharing/deepseek/dsv3-test-NextN/snapshots/",
    GSM_DATASET_PATH: "/shared/public/data/gsm8k/test.jsonl",
}


if OFFLINE_MODE:
    # Rebind the imported model names (and the dataset path) to local copies.
    # This must run before the test classes below read these names.
    DEFAULT_MODEL_NAME_FOR_TEST = OFFLINE_PATH_DICT[DEFAULT_MODEL_NAME_FOR_TEST]
    DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = OFFLINE_PATH_DICT[
        DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3
    ]
    DEFAULT_MODEL_NAME_FOR_TEST_MLA = OFFLINE_PATH_DICT[DEFAULT_MODEL_NAME_FOR_TEST_MLA]
    DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN = OFFLINE_PATH_DICT[
        DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN
    ]
    GSM_DATASET_PATH = OFFLINE_PATH_DICT[GSM_DATASET_PATH]


# Default server arguments shared across all tests.
# NOTE: this list is shared module state. Build new lists from it
# (e.g. ``[*DEFAULT_SERVER_ARGS, ...]`` or ``list(DEFAULT_SERVER_ARGS)``)
# rather than mutating it in place, or the change leaks into every test class.
DEFAULT_SERVER_ARGS = [
    "--trust-remote-code",
    "--cuda-graph-max-bs",
    "8",
    "--attention-backend",
    "fa3",
]

"""
Integration test for python/sglang/srt/layers/attention/flashattention_backend.py
"""


@unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
class BaseFlashAttentionTest(CustomTestCase):
    """Base class for testing FlashAttention3.

    Each test class launches one server (``setUpClass``) for ``model`` with the
    arguments from ``get_server_args``. It then runs a few-shot GSM8K eval
    against that server and compares the results with the thresholds below.
    """

    model = DEFAULT_MODEL_NAME_FOR_TEST
    base_url = DEFAULT_URL_FOR_TEST
    accuracy_threshold = 0.65  # derived tests need to override this
    speculative_decode = False
    spec_decode_threshold = 1.0  # derived spec decoding tests need to override this

    @classmethod
    def get_server_args(cls):
        """Return the arguments for the server launch. Override in subclasses.

        A fresh list is returned, so a caller that extends the result cannot
        mutate the module-level ``DEFAULT_SERVER_ARGS`` shared by all tests.
        """
        return list(DEFAULT_SERVER_ARGS)

    @classmethod
    def setUpClass(cls):
        """Launch the server for this test class and store it in ``cls.process``."""
        # disable deep gemm precompile to make launch server faster
        # please don't do this if you want to make your inference workload faster
        with (
            envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.override(False),
            envs.SGLANG_ENABLE_JIT_DEEPGEMM.override(False),
        ):
            cls.process = popen_launch_server(
                cls.model,
                cls.base_url,
                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                other_args=cls.get_server_args(),
            )

    @classmethod
    def tearDownClass(cls):
        """Shut down the server launched in ``setUpClass`` via kill_process_tree."""
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Check GSM8K few-shot accuracy and, optionally, speculative accept length.

        The accuracy must exceed ``accuracy_threshold``. When
        ``speculative_decode`` is set, the ``avg_spec_accept_length`` reported
        by ``/get_server_info`` must also exceed ``spec_decode_threshold``.
        """
        requests.get(self.base_url + "/flush_cache")

        # Split "scheme://host:port" so the eval talks to the same server the
        # other requests in this test use (for the default URL this yields
        # the same host and port that were previously hard-coded/parsed).
        host, _, port = self.base_url.rpartition(":")
        args = SimpleNamespace(
            num_shots=4,
            num_questions=100,
            max_new_tokens=512,
            parallel=128,
            host=host,
            port=int(port),
            data_path=GSM_DATASET_PATH,
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], self.accuracy_threshold)

        if not self.speculative_decode:
            return

        server_info = requests.get(self.base_url + "/get_server_info")
        # Surface HTTP failures directly instead of as a KeyError on the body.
        server_info.raise_for_status()
        avg_spec_accept_length = server_info.json()["internal_states"][0][
            "avg_spec_accept_length"
        ]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)


class TestFlashAttention3MLA(BaseFlashAttentionTest):
    """Test FlashAttention3 with MLA, e.g. deepseek v3 test model"""

    accuracy_threshold = 0.60
    model = DEFAULT_MODEL_NAME_FOR_TEST_MLA

    @classmethod
    def get_server_args(cls):
        """Return a copy of the shared default server arguments.

        A copy is returned so that the module-level list cannot be mutated
        through this result.
        """
        return list(DEFAULT_SERVER_ARGS)


class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
    """Test FlashAttention3 with speculative decode enabled with Llama 3.1 8B and its eagle3 model"""

    model = DEFAULT_MODEL_NAME_FOR_TEST
    accuracy_threshold = 0.65
    speculative_decode = True
    spec_decode_threshold = 1.5

    @classmethod
    def get_server_args(cls):
        """Return the default server args plus EAGLE3 (top-k = 1) speculative flags.

        A new list is built instead of extending ``DEFAULT_SERVER_ARGS`` in
        place. Mutating the shared list would leak these flags (including
        ``--dtype float16``) into every test class launched afterwards.
        """
        # "--cuda-graph-max-bs" is also in the defaults; this later occurrence
        # is expected to win (argparse last-value semantics) — confirm.
        return [
            *DEFAULT_SERVER_ARGS,
            "--cuda-graph-max-bs",
            "4",
            "--speculative-algorithm",
            "EAGLE3",
            "--speculative-draft-model-path",
            DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
            "--speculative-num-steps",
            "3",
            "--speculative-eagle-topk",
            "1",
            "--speculative-num-draft-tokens",
            "4",
            "--dtype",
            "float16",
        ]


class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
    """Tests FlashAttention3 with enhanced speculative decoding using Llama 3.1 8B and EAGLE3.
    This test will be using top-k value > 1 which would verify the other branches of the FA3 code
    """

    model = DEFAULT_MODEL_NAME_FOR_TEST
    accuracy_threshold = 0.65
    speculative_decode = True
    spec_decode_threshold = 1.6

    @classmethod
    def get_server_args(cls):
        """Return the default server args plus EAGLE3 (top-k = 4) speculative flags.

        A new list is built instead of extending ``DEFAULT_SERVER_ARGS`` in
        place, so these flags do not leak into other test classes.
        """
        # "--cuda-graph-max-bs" is also in the defaults; this later occurrence
        # is expected to win (argparse last-value semantics) — confirm.
        return [
            *DEFAULT_SERVER_ARGS,
            "--cuda-graph-max-bs",
            "4",
            "--speculative-algorithm",
            "EAGLE3",
            "--speculative-draft-model-path",
            DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
            "--speculative-num-steps",
            "5",
            "--speculative-eagle-topk",
            "4",
            "--speculative-num-draft-tokens",
            "8",
            "--dtype",
            "float16",
        ]


class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
    """Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model"""

    model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
    accuracy_threshold = 0.60
    speculative_decode = True
    spec_decode_threshold = 2.5

    @classmethod
    def get_server_args(cls):
        """Return the default server args plus EAGLE/NextN (top-k = 1) speculative flags.

        A new list is built instead of extending ``DEFAULT_SERVER_ARGS`` in
        place. This keeps these flags out of other test classes and keeps
        flags added by other classes (e.g. ``--dtype float16``) out of this one.
        """
        # "--cuda-graph-max-bs" is also in the defaults; this later occurrence
        # is expected to win (argparse last-value semantics) — confirm.
        return [
            *DEFAULT_SERVER_ARGS,
            "--cuda-graph-max-bs",
            "4",
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-draft-model-path",
            DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
            "--speculative-num-steps",
            "3",
            "--speculative-eagle-topk",
            "1",
            "--speculative-num-draft-tokens",
            "4",
        ]


class TestFlashAttention3MLASpeculativeDecodeTopk(BaseFlashAttentionTest):
    """Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model
    This test will be using top-k value > 1 which would verify the other branches of the FA3 code
    """

    model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
    accuracy_threshold = 0.60
    speculative_decode = True
    spec_decode_threshold = 2.95

    @classmethod
    def get_server_args(cls):
        """Return the default server args plus EAGLE/NextN (top-k = 4) speculative flags.

        A new list is built instead of extending ``DEFAULT_SERVER_ARGS`` in
        place, so no flags leak between test classes in either direction.
        """
        # "--cuda-graph-max-bs" is also in the defaults; this later occurrence
        # is expected to win (argparse last-value semantics) — confirm.
        return [
            *DEFAULT_SERVER_ARGS,
            "--cuda-graph-max-bs",
            "4",
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-draft-model-path",
            DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
            "--speculative-num-steps",
            "5",
            "--speculative-eagle-topk",
            "4",
            "--speculative-num-draft-tokens",
            "8",
        ]


if __name__ == "__main__":
    # Script entry point: hand control to unittest's command-line runner.
    unittest.main()