test_fa3.py 8.17 KB
Newer Older
Stefan He's avatar
Stefan He committed
1
import os
2
3
4
5
6
7
8
9
10
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import get_device_sm, kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
Stefan He's avatar
Stefan He committed
11
12
13
    DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
    DEFAULT_MODEL_NAME_FOR_TEST_MLA,
    DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
14
15
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
Lianmin Zheng's avatar
Lianmin Zheng committed
16
    CustomTestCase,
17
18
19
    popen_launch_server,
)

# Local copy of the GSM8K test split. None is forwarded as ``data_path`` to the
# few-shot GSM8K eval, which then picks its own data source
# (NOTE(review): presumably downloads it — confirm in run_eval).
GSM_DATASET_PATH = None

# Flip to True on machines without internet access: the model identifiers and
# the dataset path are then replaced by the local paths listed below.
OFFLINE_MODE = False

# Online identifier -> local path. Edit these when OFFLINE_MODE is True.
# The GSM8K entry is keyed by GSM_DATASET_PATH's online value, i.e. None.
OFFLINE_PATH_DICT = {
    DEFAULT_MODEL_NAME_FOR_TEST: "/shared/public/elr-models/meta-llama/Meta-Llama-3.1-8B-Instruct",
    DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3: "/shared/public/elr-models/jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
    DEFAULT_MODEL_NAME_FOR_TEST_MLA: "/shared/public/sharing/deepseek/dsv3-test/snapshots/",
    DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN: "/shared/public/sharing/deepseek/dsv3-test-NextN/snapshots/",
    GSM_DATASET_PATH: "/shared/public/data/gsm8k/test.jsonl",
}


if OFFLINE_MODE:
    # Resolve every online identifier through the offline table at once.
    # This runs before the test classes below read these names as class
    # attributes, so they pick up the local paths.
    (
        DEFAULT_MODEL_NAME_FOR_TEST,
        DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
        DEFAULT_MODEL_NAME_FOR_TEST_MLA,
        DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
        GSM_DATASET_PATH,
    ) = (
        OFFLINE_PATH_DICT[DEFAULT_MODEL_NAME_FOR_TEST],
        OFFLINE_PATH_DICT[DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3],
        OFFLINE_PATH_DICT[DEFAULT_MODEL_NAME_FOR_TEST_MLA],
        OFFLINE_PATH_DICT[DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN],
        OFFLINE_PATH_DICT[GSM_DATASET_PATH],
    )


# Server launch flags shared by every test class in this file:
#   --trust-remote-code, --cuda-graph-max-bs 8, --attention-backend fa3.
# This list is module-level shared state: build a new list (copy or
# concatenate) for class-specific flags instead of extending it in place.
DEFAULT_SERVER_ARGS = [
    "--trust-remote-code",
    *("--cuda-graph-max-bs", "8"),
    *("--attention-backend", "fa3"),
]

# Integration test for python/sglang/srt/layers/attention/flashattention_backend.py


@unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
class BaseFlashAttentionTest(CustomTestCase):
    """Base class for testing FlashAttention3.

    Each test class launches one server with ``get_server_args()``. It then
    runs a few-shot GSM8K eval against it and asserts that accuracy exceeds
    ``accuracy_threshold``. When ``speculative_decode`` is True, it also asserts
    that the server-reported ``avg_spec_accept_length`` exceeds
    ``spec_decode_threshold``.
    """

    model = DEFAULT_MODEL_NAME_FOR_TEST
    base_url = DEFAULT_URL_FOR_TEST
    accuracy_threshold = 0.65  # derived tests need to override this
    speculative_decode = False
    spec_decode_threshold = 1.0  # derived spec decoding tests need to override this

    @classmethod
    def get_server_args(cls):
        """Return the arguments for the server launch. Override in subclasses.

        Returns a fresh copy of DEFAULT_SERVER_ARGS. Callers may extend the
        result without mutating the module-level defaults that every test
        class shares.
        """
        return list(DEFAULT_SERVER_ARGS)

    @classmethod
    def setUpClass(cls):
        # disable deep gemm precompile to make launch server faster
        # please don't do this if you want to make your inference workload faster
        os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
        os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=cls.get_server_args(),
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Run a 4-shot GSM8K eval (100 questions) and check the thresholds."""
        requests.get(self.base_url + "/flush_cache")

        args = SimpleNamespace(
            num_shots=4,
            num_questions=100,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            # The eval connects to host:port; the port is taken from base_url.
            port=int(self.base_url.split(":")[-1]),
            data_path=GSM_DATASET_PATH,
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"{metrics=}")

        self.assertGreater(metrics["accuracy"], self.accuracy_threshold)

        if self.speculative_decode:
            server_info = requests.get(self.base_url + "/get_server_info")
            # Fail with the HTTP error rather than an opaque JSON/KeyError.
            server_info.raise_for_status()
            avg_spec_accept_length = server_info.json()["internal_states"][0][
                "avg_spec_accept_length"
            ]
            print(f"{avg_spec_accept_length=}")
            self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)


class TestFlashAttention3MLA(BaseFlashAttentionTest):
    """Test FlashAttention3 with MLA, e.g. deepseek v3 test model"""

    accuracy_threshold = 0.60
    model = DEFAULT_MODEL_NAME_FOR_TEST_MLA

    @classmethod
    def get_server_args(cls):
        """Return only the default server flags, as a copy.

        A copy is returned so the module-level DEFAULT_SERVER_ARGS list cannot
        be mutated by anything that extends this method's result.
        """
        return list(DEFAULT_SERVER_ARGS)

class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
    """Test FlashAttention3 with speculative decode enabled with Llama 3.1 8B and its eagle3 model"""

    model = DEFAULT_MODEL_NAME_FOR_TEST
    accuracy_threshold = 0.65
    speculative_decode = True
    spec_decode_threshold = 1.5

    @classmethod
    def get_server_args(cls):
        """Return the default flags plus EAGLE3 speculative-decoding flags.

        Draft config: 3 steps, eagle top-k 1, 4 draft tokens, float16 dtype.
        The result is a new list built by concatenation. DEFAULT_SERVER_ARGS
        is shared module-level state, so extending it in place would leak
        these flags into every test class launched afterwards.
        """
        # NOTE(review): "--cuda-graph-max-bs" also appears in DEFAULT_SERVER_ARGS
        # with "8"; this relies on the later "2" winning (argparse "store"
        # semantics) — confirm against sglang's server argument parser.
        return DEFAULT_SERVER_ARGS + [
            "--cuda-graph-max-bs",
            "2",
            "--speculative-algorithm",
            "EAGLE3",
            "--speculative-draft",
            DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
            "--speculative-num-steps",
            "3",
            "--speculative-eagle-topk",
            "1",
            "--speculative-num-draft-tokens",
            "4",
            "--dtype",
            "float16",
        ]


class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
    """Tests FlashAttention3 with enhanced speculative decoding using Llama 3.1 8B and EAGLE3.
    This test will be using top-k value > 1 which would verify the other branches of the FA3 code
    """

    model = DEFAULT_MODEL_NAME_FOR_TEST
    accuracy_threshold = 0.65
    speculative_decode = True
    spec_decode_threshold = 1.5

    @classmethod
    def get_server_args(cls):
        """Return the default flags plus EAGLE3 flags with top-k > 1.

        Draft config: 5 steps, eagle top-k 4, 8 draft tokens, float16 dtype.
        The result is a new list built by concatenation. DEFAULT_SERVER_ARGS
        is shared module-level state, so extending it in place would leak
        these flags into every test class launched afterwards.
        """
        # NOTE(review): "--cuda-graph-max-bs" also appears in DEFAULT_SERVER_ARGS
        # with "8"; this relies on the later "2" winning (argparse "store"
        # semantics) — confirm against sglang's server argument parser.
        return DEFAULT_SERVER_ARGS + [
            "--cuda-graph-max-bs",
            "2",
            "--speculative-algorithm",
            "EAGLE3",
            "--speculative-draft",
            DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
            "--speculative-num-steps",
            "5",
            "--speculative-eagle-topk",
            "4",
            "--speculative-num-draft-tokens",
            "8",
            "--dtype",
            "float16",
        ]


class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
    """Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model"""

    model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
    accuracy_threshold = 0.60
    speculative_decode = True
    spec_decode_threshold = 1.5

    @classmethod
    def get_server_args(cls):
        """Return the default flags plus EAGLE (NextN draft) speculative flags.

        Draft config: 3 steps, eagle top-k 1, 4 draft tokens. The result is a
        new list built by concatenation. DEFAULT_SERVER_ARGS is shared
        module-level state, so extending it in place would leak these flags
        into every test class launched afterwards.
        """
        # NOTE(review): "--cuda-graph-max-bs" also appears in DEFAULT_SERVER_ARGS
        # with "8"; this relies on the later "2" winning (argparse "store"
        # semantics) — confirm against sglang's server argument parser.
        return DEFAULT_SERVER_ARGS + [
            "--cuda-graph-max-bs",
            "2",
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-draft",
            DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
            "--speculative-num-steps",
            "3",
            "--speculative-eagle-topk",
            "1",
            "--speculative-num-draft-tokens",
            "4",
        ]


class TestFlashAttention3MLASpeculativeDecodeTopk(BaseFlashAttentionTest):
    """Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model
    This test will be using top-k value > 1 which would verify the other branches of the FA3 code
    """

    model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
    accuracy_threshold = 0.60
    speculative_decode = True
    spec_decode_threshold = 1.5

    @classmethod
    def get_server_args(cls):
        """Return the default flags plus EAGLE (NextN draft) flags with top-k > 1.

        Draft config: 5 steps, eagle top-k 4, 8 draft tokens. The result is a
        new list built by concatenation. DEFAULT_SERVER_ARGS is shared
        module-level state, so extending it in place would leak these flags
        into every test class launched afterwards.
        """
        # NOTE(review): "--cuda-graph-max-bs" also appears in DEFAULT_SERVER_ARGS
        # with "8"; this relies on the later "2" winning (argparse "store"
        # semantics) — confirm against sglang's server argument parser.
        return DEFAULT_SERVER_ARGS + [
            "--cuda-graph-max-bs",
            "2",
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-draft",
            DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
            "--speculative-num-steps",
            "5",
            "--speculative-eagle-topk",
            "4",
            "--speculative-num-draft-tokens",
            "8",
        ]


if __name__ == "__main__":
    # Direct execution: discover and run every TestCase in this module
    # with the standard unittest runner.
    unittest.main(module="__main__")