import os
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import get_device_sm, kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
    DEFAULT_MODEL_NAME_FOR_TEST_MLA,
    DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)

# Location of the GSM8K test set. None leaves the choice of data source to
# run_eval_few_shot_gsm8k (presumably it downloads the set — confirm).
GSM_DATASET_PATH = None

# Set OFFLINE_MODE to True on machines without internet access; the models and
# the dataset are then taken from the local paths in OFFLINE_PATH_DICT.
OFFLINE_MODE = False

# Local replacements used when OFFLINE_MODE is True, keyed by the online value
# they replace (the dataset entry is keyed by None, the online value of
# GSM_DATASET_PATH). Change these paths to match your machine.
OFFLINE_PATH_DICT = {
    DEFAULT_MODEL_NAME_FOR_TEST: "/shared/public/elr-models/meta-llama/Meta-Llama-3.1-8B-Instruct",
    DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3: "/shared/public/elr-models/jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
    DEFAULT_MODEL_NAME_FOR_TEST_MLA: "/shared/public/sharing/deepseek/dsv3-test/snapshots/",
    DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN: "/shared/public/sharing/deepseek/dsv3-test-NextN/snapshots/",
    GSM_DATASET_PATH: "/shared/public/data/gsm8k/test.jsonl",
}

if OFFLINE_MODE:
    # Swap every online name for its local counterpart in one step. The key
    # tuple inside the comprehension is evaluated before any name is rebound.
    (
        DEFAULT_MODEL_NAME_FOR_TEST,
        DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
        DEFAULT_MODEL_NAME_FOR_TEST_MLA,
        DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
        GSM_DATASET_PATH,
    ) = [
        OFFLINE_PATH_DICT[online_key]
        for online_key in (
            DEFAULT_MODEL_NAME_FOR_TEST,
            DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
            DEFAULT_MODEL_NAME_FOR_TEST_MLA,
            DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
            GSM_DATASET_PATH,
        )
    ]


# Server launch flags shared by every test class in this module, written as
# flag/value pairs. The list object itself is shared: build a new list from it
# (e.g. [*DEFAULT_SERVER_ARGS, ...]) rather than extending it in place.
DEFAULT_SERVER_ARGS = [
    "--trust-remote-code",
    "--cuda-graph-max-bs", "4",
    "--attention-backend", "fa3",
]

"""
Integration test for python/sglang/srt/layers/attention/flashattention_backend.py
"""


@unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
class BaseFlashAttentionTest(CustomTestCase):
    """Base class for testing FlashAttention3.

    Each test class launches one server (``popen_launch_server``) with the
    arguments from ``get_server_args`` and runs a few-shot GSM8K evaluation
    against it. Subclasses customize the run through the class attributes
    below and by overriding ``get_server_args``.
    """

    # Model served for this test class.
    model = DEFAULT_MODEL_NAME_FOR_TEST
    base_url = DEFAULT_URL_FOR_TEST
    # Minimum GSM8K accuracy; derived tests need to override this.
    accuracy_threshold = 0.65
    # When True, test_gsm8k also checks avg_spec_accept_length from /get_server_info.
    speculative_decode = False
    # Minimum avg_spec_accept_length; derived spec decoding tests need to override this.
    spec_decode_threshold = 1.0

    @classmethod
    def get_server_args(cls):
        """Return the arguments for the server launch. Override in subclasses.

        Returns a new list so that callers (or subclasses that extend the
        result) cannot mutate the shared module-level DEFAULT_SERVER_ARGS and
        leak flags into other test classes.
        """
        return list(DEFAULT_SERVER_ARGS)

    @classmethod
    def setUpClass(cls):
        """Launch the server for ``cls.model`` with ``cls.get_server_args()``."""
        # disable deep gemm precompile to make launch server faster
        # please don't do this if you want to make your inference workload faster
        os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
        os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=cls.get_server_args(),
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the launched server process tree."""
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Check 4-shot GSM8K accuracy on 100 questions against the server.

        When ``speculative_decode`` is set, also check that the server's
        reported ``avg_spec_accept_length`` exceeds ``spec_decode_threshold``.
        """
        requests.get(self.base_url + "/flush_cache")

        # Take host and port from base_url so the eval targets the same server
        # the test launched (the host used to be hard-coded to 127.0.0.1).
        host, _, port = self.base_url.rpartition(":")
        args = SimpleNamespace(
            num_shots=4,
            num_questions=100,
            max_new_tokens=512,
            parallel=128,
            host=host,
            port=int(port),
            data_path=GSM_DATASET_PATH,
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"{metrics=}")

        self.assertGreater(metrics["accuracy"], self.accuracy_threshold)

        if self.speculative_decode:
            server_info = requests.get(self.base_url + "/get_server_info")
            # Fail with the HTTP error instead of an opaque JSON/KeyError.
            server_info.raise_for_status()
            avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
            print(f"{avg_spec_accept_length=}")
            self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)


class TestFlashAttention3MLA(BaseFlashAttentionTest):
    """Test FlashAttention3 with MLA, e.g. deepseek v3 test model"""

    accuracy_threshold = 0.60
    model = DEFAULT_MODEL_NAME_FOR_TEST_MLA

    @classmethod
    def get_server_args(cls):
        """Return the default server arguments (no MLA-specific flags).

        Returns a new list so the shared module-level DEFAULT_SERVER_ARGS is
        never handed out for callers to mutate.
        """
        return list(DEFAULT_SERVER_ARGS)

class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
    """Test FlashAttention3 with speculative decode enabled with Llama 3.1 8B and its eagle3 model"""

    model = DEFAULT_MODEL_NAME_FOR_TEST
    accuracy_threshold = 0.65
    speculative_decode = True
    spec_decode_threshold = 1.5

    @classmethod
    def get_server_args(cls):
        """Return the default args plus EAGLE3 speculative decoding flags (top-k 1).

        Builds a new list instead of extending DEFAULT_SERVER_ARGS in place.
        Extending in place would mutate the shared module-level list and leak
        these flags into every test class whose args are built afterwards.
        """
        return [
            *DEFAULT_SERVER_ARGS,
            # Repeats the flag from DEFAULT_SERVER_ARGS with a smaller value;
            # presumably the last occurrence wins (argparse-style) — confirm.
            "--cuda-graph-max-bs",
            "2",
            "--speculative-algorithm",
            "EAGLE3",
            "--speculative-draft",
            DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
            "--speculative-num-steps",
            "3",
            "--speculative-eagle-topk",
            "1",
            "--speculative-num-draft-tokens",
            "4",
            "--dtype",
            "float16",
        ]


class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
    """Tests FlashAttention3 with enhanced speculative decoding using Llama 3.1 8B and EAGLE3.
    This test will be using top-k value > 1 which would verify the other branches of the FA3 code
    """

    model = DEFAULT_MODEL_NAME_FOR_TEST
    accuracy_threshold = 0.65
    speculative_decode = True
    spec_decode_threshold = 1.5

    @classmethod
    def get_server_args(cls):
        """Return the default args plus EAGLE3 speculative decoding flags (top-k 4).

        Builds a new list instead of extending DEFAULT_SERVER_ARGS in place.
        Extending in place would mutate the shared module-level list and leak
        these flags into every test class whose args are built afterwards.
        """
        return [
            *DEFAULT_SERVER_ARGS,
            # Repeats the flag from DEFAULT_SERVER_ARGS with a smaller value;
            # presumably the last occurrence wins (argparse-style) — confirm.
            "--cuda-graph-max-bs",
            "2",
            "--speculative-algorithm",
            "EAGLE3",
            "--speculative-draft",
            DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
            "--speculative-num-steps",
            "5",
            "--speculative-eagle-topk",
            "4",
            "--speculative-num-draft-tokens",
            "8",
            "--dtype",
            "float16",
        ]


class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
    """Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model"""

    model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
    accuracy_threshold = 0.60
    speculative_decode = True
    spec_decode_threshold = 1.5

    @classmethod
    def get_server_args(cls):
        """Return the default args plus EAGLE/NextN speculative decoding flags (top-k 1).

        Builds a new list instead of extending DEFAULT_SERVER_ARGS in place.
        Extending in place would mutate the shared module-level list and leak
        these flags into every test class whose args are built afterwards.
        """
        return [
            *DEFAULT_SERVER_ARGS,
            # Repeats the flag from DEFAULT_SERVER_ARGS with a smaller value;
            # presumably the last occurrence wins (argparse-style) — confirm.
            "--cuda-graph-max-bs",
            "2",
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-draft",
            DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
            "--speculative-num-steps",
            "3",
            "--speculative-eagle-topk",
            "1",
            "--speculative-num-draft-tokens",
            "4",
        ]


class TestFlashAttention3MLASpeculativeDecodeTopk(BaseFlashAttentionTest):
    """Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model
    This test will be using top-k value > 1 which would verify the other branches of the FA3 code
    """

    model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
    accuracy_threshold = 0.60
    speculative_decode = True
    spec_decode_threshold = 1.5

    @classmethod
    def get_server_args(cls):
        """Return the default args plus EAGLE/NextN speculative decoding flags (top-k 4).

        Builds a new list instead of extending DEFAULT_SERVER_ARGS in place.
        Extending in place would mutate the shared module-level list and leak
        these flags into every test class whose args are built afterwards.
        """
        return [
            *DEFAULT_SERVER_ARGS,
            # Repeats the flag from DEFAULT_SERVER_ARGS with a smaller value;
            # presumably the last occurrence wins (argparse-style) — confirm.
            "--cuda-graph-max-bs",
            "2",
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-draft",
            DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
            "--speculative-num-steps",
            "5",
            "--speculative-eagle-topk",
            "4",
            "--speculative-num-draft-tokens",
            "8",
        ]


if __name__ == "__main__":
    # Executed directly (not imported): run the tests defined in this module.
    unittest.main()