test_fa3.py 9.71 KB
Newer Older
Stefan He's avatar
Stefan He committed
1
import os
2
3
4
5
6
7
8
9
10
11
import unittest
from types import SimpleNamespace

import requests
import torch

from sglang.srt.utils import get_device_sm, kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
Stefan He's avatar
Stefan He committed
12
13
14
    DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
    DEFAULT_MODEL_NAME_FOR_TEST_MLA,
    DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
15
16
17
18
19
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

Stefan He's avatar
Stefan He committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# Local path to the GSM8K test split. ``None`` leaves the choice of dataset
# source to ``run_eval_few_shot_gsm8k`` (presumably its default download).
GSM_DATASET_PATH = None

# On machines without internet access, set OFFLINE_MODE to True (or export
# SGLANG_TEST_OFFLINE_MODE=1) so models and data are loaded from local paths.
OFFLINE_MODE = os.environ.get("SGLANG_TEST_OFFLINE_MODE", "").lower() in (
    "1",
    "true",
    "yes",
)

# Local replacements for the model names; change these paths to match the
# machine when OFFLINE_MODE is enabled.
OFFLINE_PATH_DICT = {
    DEFAULT_MODEL_NAME_FOR_TEST: "/shared/public/elr-models/meta-llama/Meta-Llama-3.1-8B-Instruct",
    DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3: "/shared/public/elr-models/jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
    DEFAULT_MODEL_NAME_FOR_TEST_MLA: "/shared/public/sharing/deepseek/dsv3-test/snapshots/",
    DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN: "/shared/public/sharing/deepseek/dsv3-test-NextN/snapshots/",
}
# Kept out of OFFLINE_PATH_DICT: GSM_DATASET_PATH defaults to ``None``, which
# is not a meaningful lookup key.
OFFLINE_GSM_DATASET_PATH = "/shared/public/data/gsm8k/test.jsonl"


if OFFLINE_MODE:
    DEFAULT_MODEL_NAME_FOR_TEST = OFFLINE_PATH_DICT[DEFAULT_MODEL_NAME_FOR_TEST]
    DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = OFFLINE_PATH_DICT[
        DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3
    ]
    DEFAULT_MODEL_NAME_FOR_TEST_MLA = OFFLINE_PATH_DICT[DEFAULT_MODEL_NAME_FOR_TEST_MLA]
    DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN = OFFLINE_PATH_DICT[
        DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN
    ]
    GSM_DATASET_PATH = OFFLINE_GSM_DATASET_PATH


# Default server arguments shared across all tests. Treat this list as
# read-only: copy it before adding flags, otherwise the additions leak into
# every test class whose server is launched afterwards.
DEFAULT_SERVER_ARGS = [
    "--trust-remote-code",
    "--enable-torch-compile",
    "--cuda-graph-max-bs",
    "2",
    "--attention-backend",
    "fa3",
]

# Integration test for python/sglang/srt/layers/attention/flashattention_backend.py


@unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
class BaseFlashAttentionTest(unittest.TestCase):
    """Base class for testing FlashAttention3.

    Launches one server per test class with ``get_server_args()`` and checks
    few-shot GSM8K accuracy against ``accuracy_threshold``. When
    ``speculative_decode`` is set, it also checks the server-reported
    ``avg_spec_accept_length`` against ``spec_decode_threshold``.
    """

    model = DEFAULT_MODEL_NAME_FOR_TEST
    base_url = DEFAULT_URL_FOR_TEST
    accuracy_threshold = 0.65  # derived tests need to override this
    speculative_decode = False
    spec_decode_threshold = 1.0  # derived spec decoding tests need to override this

    @classmethod
    def get_server_args(cls):
        """Return the arguments for the server launch. Override in subclasses.

        A fresh list is returned, so callers (including subclasses that extend
        the result) can never mutate the shared ``DEFAULT_SERVER_ARGS``.
        """
        return list(DEFAULT_SERVER_ARGS)

    @classmethod
    def setUpClass(cls):
        # Disable DeepGEMM precompile to make the server launch faster.
        # Please don't do this if you want your inference workload to be faster.
        # The override goes into a copy of the environment handed to the
        # server, so this test process's os.environ is left untouched.
        server_env = {**os.environ, "SGL_JIT_DEEPGEMM_PRECOMPILE": "False"}
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=cls.get_server_args(),
            env=server_env,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Run few-shot GSM8K against the server and check the thresholds."""
        # Derive host and port from the same URL the server was launched on.
        host, _, port = self.base_url.rpartition(":")
        args = SimpleNamespace(
            num_shots=4,
            num_questions=100,
            max_new_tokens=512,
            parallel=128,
            host=host,
            port=int(port),
            data_path=GSM_DATASET_PATH,
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(metrics)

        self.assertGreater(metrics["accuracy"], self.accuracy_threshold)

        if self.speculative_decode:
            server_info = requests.get(self.base_url + "/get_server_info", timeout=60)
            # Fail with the HTTP error instead of a confusing JSON/KeyError.
            server_info.raise_for_status()
            avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
            print(f"{avg_spec_accept_length=}")
            self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)


class TestFlashAttention3MLA(BaseFlashAttentionTest):
    """Test FlashAttention3 with MLA, e.g. deepseek v3 test model"""

    accuracy_threshold = 0.60
    model = DEFAULT_MODEL_NAME_FOR_TEST_MLA

    @classmethod
    def get_server_args(cls):
        """Return the default server arguments as a new list.

        Copying prevents callers from mutating the shared module-level
        argument list through the returned value.
        """
        return list(super().get_server_args())

128

Stefan He's avatar
Stefan He committed
129
130
class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
    """Test FlashAttention3 with speculative decode enabled with Llama 3.1 8B and its eagle3 model"""

    model = DEFAULT_MODEL_NAME_FOR_TEST
    accuracy_threshold = 0.65
    speculative_decode = True
    spec_decode_threshold = 1.5

    @classmethod
    def get_server_args(cls):
        """Return the default args plus EAGLE3 speculative decoding (topk=1).

        A new list is built rather than extending the inherited one in place,
        so these flags never leak into ``DEFAULT_SERVER_ARGS`` and, from there,
        into the servers of other test classes.
        """
        return [
            *super().get_server_args(),
            "--cuda-graph-max-bs",
            "2",
            "--speculative-algorithm",
            "EAGLE3",
            "--speculative-draft",
            DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
            "--speculative-num-steps",
            "3",
            "--speculative-eagle-topk",
            "1",
            "--speculative-num-draft-tokens",
            "4",
            "--dtype",
            "float16",
        ]


Stefan He's avatar
Stefan He committed
161
162
163
164
# NOTE: a duplicate ``class TestFlashAttention3SpeculativeDecodeTopk`` with the
# same server arguments (EAGLE3, 5 steps, topk 4, 8 draft tokens, float16) was
# defined here. Python keeps only the last module-level definition of a name,
# so this copy was silently shadowed by the class of the same name further
# below and never ran; it has been removed.


195
196
197
class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
    """Test FlashAttention3 with speculative decode enabled, topk > 1.

    Uses Llama 3.1 8B with its EAGLE3 draft model. A top-k value > 1 covers
    the other branches of the FA3 code compared with the topk=1 test.
    """

    model = DEFAULT_MODEL_NAME_FOR_TEST
    accuracy_threshold = 0.60
    speculative_decode = True
    spec_decode_threshold = 1.8

    @classmethod
    def get_server_args(cls):
        """Return the inherited args plus EAGLE3 speculative decoding (topk=4).

        A new list is built so the inherited (possibly shared) argument list
        is never extended in place.
        """
        return [
            *super().get_server_args(),
            "--cuda-graph-max-bs",
            "2",
            "--speculative-algorithm",
            "EAGLE3",
            "--speculative-draft",
            DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
            "--speculative-num-steps",
            "5",
            "--speculative-eagle-topk",
            "4",
            "--speculative-num-draft-tokens",
            "8",
            "--dtype",
            "float16",
        ]

    def test_gsm8k(self):
        """
        Override the test_gsm8k to further test for average speculative accept length.

        Runs a larger GSM8K sample (5 shots, 200 questions) than the base test
        and checks accuracy and accept length against the class thresholds.
        """
        # Best-effort cache flush before measuring (presumably so earlier
        # requests don't skew the numbers); the response is not checked.
        requests.get(self.base_url + "/flush_cache", timeout=60)

        host, _, port = self.base_url.rpartition(":")
        args = SimpleNamespace(
            num_shots=5,
            data_path=GSM_DATASET_PATH,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host=host,
            port=int(port),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(metrics)

        self.assertGreater(metrics["accuracy"], self.accuracy_threshold)

        server_info = requests.get(self.base_url + "/get_server_info", timeout=60)
        server_info.raise_for_status()
        avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)


249
class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
    """Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model"""

    model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
    accuracy_threshold = 0.60
    speculative_decode = True
    spec_decode_threshold = 1.5

    @classmethod
    def get_server_args(cls):
        """Return the inherited args plus EAGLE (NextN draft) decoding, topk=1.

        A new list is built instead of extending ``DEFAULT_SERVER_ARGS`` in
        place, which would leak these flags into later test classes.
        """
        return [
            *super().get_server_args(),
            "--cuda-graph-max-bs",
            "2",
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-draft",
            DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
            "--speculative-num-steps",
            "3",
            "--speculative-eagle-topk",
            "1",
            "--speculative-num-draft-tokens",
            "4",
        ]


Stefan He's avatar
Stefan He committed
279
280
281
282
class TestFlashAttention3MLASpeculativeDecodeTopk(BaseFlashAttentionTest):
    """Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model
    This test will be using top-k value > 1 which would verify the other branches of the FA3 code
    """

    model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
    accuracy_threshold = 0.60
    speculative_decode = True
    spec_decode_threshold = 1.5

    @classmethod
    def get_server_args(cls):
        """Return the inherited args plus EAGLE (NextN draft) decoding, topk=4.

        A new list is built instead of extending ``DEFAULT_SERVER_ARGS`` in
        place, which would leak these flags into later test classes.
        """
        return [
            *super().get_server_args(),
            "--cuda-graph-max-bs",
            "2",
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-draft",
            DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
            "--speculative-num-steps",
            "5",
            "--speculative-eagle-topk",
            "4",
            "--speculative-num-draft-tokens",
            "8",
        ]
309
310


311
312
if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main(module="__main__")