"scripts/run_wan_t2v_save_quant.sh" did not exist on "f21528e750271dfb5b8271128b1d017815034d8a"
test_fa3.py 10.4 KB
Newer Older
Stefan He's avatar
Stefan He committed
1
import os
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import get_device_sm, kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
    DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION,
    DEFAULT_MODEL_NAME_FOR_TEST_MLA,
    DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

GSM_DATASET_PATH = None

# If a machine lacks an internet connection, set OFFLINE_MODE to True.
OFFLINE_MODE = False

# Change the paths below when OFFLINE_MODE is True.
OFFLINE_PATH_DICT = {
    DEFAULT_MODEL_NAME_FOR_TEST: "/shared/public/elr-models/meta-llama/Meta-Llama-3.1-8B-Instruct",
    DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3: "/shared/public/elr-models/jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
    DEFAULT_MODEL_NAME_FOR_TEST_MLA: "/shared/public/sharing/deepseek/dsv3-test/snapshots/",
    DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN: "/shared/public/sharing/deepseek/dsv3-test-NextN/snapshots/",
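    # NOTE: GSM_DATASET_PATH is still None at this point, so the entry below is
    # keyed by None; the OFFLINE_MODE block below looks it up the same way.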
    GSM_DATASET_PATH: "/shared/public/data/gsm8k/test.jsonl",
}


if OFFLINE_MODE:
    DEFAULT_MODEL_NAME_FOR_TEST = OFFLINE_PATH_DICT[DEFAULT_MODEL_NAME_FOR_TEST]
    DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = OFFLINE_PATH_DICT[
        DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3
    ]
    DEFAULT_MODEL_NAME_FOR_TEST_MLA = OFFLINE_PATH_DICT[DEFAULT_MODEL_NAME_FOR_TEST_MLA]
    DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN = OFFLINE_PATH_DICT[
        DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN
    ]
    GSM_DATASET_PATH = OFFLINE_PATH_DICT[GSM_DATASET_PATH]


# Default server arguments shared across all tests
DEFAULT_SERVER_ARGS = [
    "--trust-remote-code",
    "--enable-torch-compile",
    "--cuda-graph-max-bs",
    "2",
    "--attention-backend",
    "fa3",
]
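# popen_launch_server forwards these as CLI flags to the sglang server; each
# test class appends its own flags by overriding get_server_args().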

"""
Integration test for python/sglang/srt/layers/attention/flashattention_backend.py
"""


@unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
class BaseFlashAttentionTest(unittest.TestCase):
    """Base class for testing FlashAttention3."""

    model = DEFAULT_MODEL_NAME_FOR_TEST
    base_url = DEFAULT_URL_FOR_TEST
    accuracy_threshold = 0.65  # derived tests may override this
    speculative_decode = False
    spec_decode_threshold = 1.0  # derived spec-decoding tests should override this

    @classmethod
    def get_server_args(cls):
        """Return the arguments for the server launch. Override in subclasses."""
        return DEFAULT_SERVER_ARGS.copy()  # copy so subclasses can extend safely

    @classmethod
    def setUpClass(cls):
        # Disable DeepGEMM JIT precompile so the server launches faster.
        # Do not do this if you care about inference speed itself.
        os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "False"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=cls.get_server_args(),
            env=os.environ,  # pass the patched environment so the flag above applies
        )

    @classmethod
    def tearDownClass(cls):
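        # The server spawns worker subprocesses, so kill the whole process
        # tree rather than only the parent.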
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        args = SimpleNamespace(
            num_shots=4,
            num_questions=100,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
            data_path=GSM_DATASET_PATH,
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(metrics)

        self.assertGreater(metrics["accuracy"], self.accuracy_threshold)

        if self.speculative_decode:
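            # avg_spec_accept_length is the average number of tokens accepted
            # per draft-and-verify step; above 1.0, speculation is paying off.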
            server_info = requests.get(self.base_url + "/get_server_info")
            avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
            print(f"{avg_spec_accept_length=}")
            self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)


class TestFlashAttention3MLA(BaseFlashAttentionTest):
    """Test FlashAttention3 with MLA, e.g. deepseek v3 test model"""

    accuracy_threshold = 0.60
    model = DEFAULT_MODEL_NAME_FOR_TEST_MLA

    @classmethod
    def get_server_args(cls):
        return DEFAULT_SERVER_ARGS.copy()


class TestFlashAttention3LocalAttn(BaseFlashAttentionTest):
    """Test FlashAttention3 with Model with local attention, e.g. Llama 4."""

    accuracy_threshold = 0.70
    model = DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION

    @classmethod
    def get_server_args(cls):
        cloned_args = DEFAULT_SERVER_ARGS.copy()
        # Remove --enable-torch-compile since Llama 4 does not support it yet.
        cloned_args.remove("--enable-torch-compile")
        # We cannot use Scout's 10M context due to this bug:
        # https://github.com/sgl-project/sglang/issues/5755
        cloned_args.extend(["--tp", "4", "--context-length", "1000000"])
        return cloned_args


class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
    """Test FlashAttention3 with speculative decode enabled with Llama 3.1 8B and its eagle3 model"""

    model = DEFAULT_MODEL_NAME_FOR_TEST
    accuracy_threshold = 0.65
    speculative_decode = True
    spec_decode_threshold = 1.5

    @classmethod
    def get_server_args(cls):
        args = DEFAULT_SERVER_ARGS.copy()  # copy so we do not mutate the shared list
        args.extend(
            [
                "--speculative-algorithm",
                "EAGLE3",
                "--speculative-draft",
                DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
                "--speculative-num-steps",
                "3",
                "--speculative-eagle-topk",
                "1",
                "--speculative-num-draft-tokens",
                "4",
                "--dtype",
                "float16",
            ]
        )
        return args


class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
    """Tests FlashAttention3 with enhanced speculative decoding using Llama 3.1 8B and EAGLE3.
    This test will be using top-k value > 1 which would verify the other branches of the FA3 code
    """

    model = DEFAULT_MODEL_NAME_FOR_TEST
    accuracy_threshold = 0.65
    speculative_decode = True
    spec_decode_threshold = 1.5

    @classmethod
    def get_server_args(cls):
        args = DEFAULT_SERVER_ARGS.copy()  # copy so we do not mutate the shared list
        args.extend(
            [
                "--speculative-algorithm",
                "EAGLE3",
                "--speculative-draft",
                DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
                "--speculative-num-steps",
                "5",
                "--speculative-eagle-topk",
                "4",
                "--speculative-num-draft-tokens",
                "8",
                "--dtype",
                "float16",
            ]
        )
        return args


class TestFlashAttention3SpeculativeDecodeTopkFlushCache(BaseFlashAttentionTest):
    """Test FlashAttention3 with speculative decode enabled, topk > 1, flushing
    the cache first and checking accuracy and accept length directly."""

    model = DEFAULT_MODEL_NAME_FOR_TEST

    @classmethod
    def get_server_args(cls):
        args = super().get_server_args()
        args.extend(
            [
                "--speculative-algorithm",
                "EAGLE3",
                "--speculative-draft",
                DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
                "--speculative-num-steps",
                "5",
                "--speculative-eagle-topk",
                "4",
                "--speculative-num-draft-tokens",
                "8",
                "--dtype",
                "float16",
            ]
        )
        return args

    def test_gsm8k(self):
        """
        Override the test_gsm8k to further test for average speculative accept length.
        """
        requests.get(self.base_url + "/flush_cache")
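        # Flush the prefix cache so hits from earlier tests do not skew the
        # accept-length measurement below.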

        args = SimpleNamespace(
            num_shots=5,
            data_path=GSM_DATASET_PATH,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(metrics)

        self.assertGreater(metrics["accuracy"], 0.60)

        server_info = requests.get(self.base_url + "/get_server_info")
        avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, 1.8)


class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
    """Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model"""

    model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
    accuracy_threshold = 0.60
    speculative_decode = True
    spec_decode_threshold = 1.5

    @classmethod
    def get_server_args(cls):
        args = DEFAULT_SERVER_ARGS.copy()  # copy so we do not mutate the shared list
        args.extend(
            [
                "--speculative-algorithm",
                "EAGLE",
                "--speculative-draft",
                DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
                "--speculative-num-steps",
                "3",
                "--speculative-eagle-topk",
                "1",
                "--speculative-num-draft-tokens",
                "4",
            ]
        )
        return args


class TestFlashAttention3MLASpeculativeDecodeTopk(BaseFlashAttentionTest):
    """Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model
    This test will be using top-k value > 1 which would verify the other branches of the FA3 code
    """

    model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
    accuracy_threshold = 0.60
    speculative_decode = True
    spec_decode_threshold = 1.5

    @classmethod
    def get_server_args(cls):
        args = DEFAULT_SERVER_ARGS.copy()  # copy so we do not mutate the shared list
        args.extend(
            [
                "--speculative-algorithm",
                "EAGLE",
                "--speculative-draft",
                DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
                "--speculative-num-steps",
                "5",
                "--speculative-eagle-topk",
                "4",
                "--speculative-num-draft-tokens",
                "8",
            ]
        )
        return args


if __name__ == "__main__":
    unittest.main()