Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
4d23ba08
Unverified
Commit
4d23ba08
authored
Apr 27, 2025
by
Lianmin Zheng
Committed by
GitHub
Apr 27, 2025
Browse files
Simplify FA3 tests (#5779)
parent
6e313c1b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
14 additions
and
67 deletions
+14
-67
test/srt/run_suite.py
test/srt/run_suite.py
+2
-2
test/srt/test_fa3.py
test/srt/test_fa3.py
+8
-61
test/srt/test_local_attn.py
test/srt/test_local_attn.py
+4
-4
No files found.
test/srt/run_suite.py
View file @
4d23ba08
...
@@ -30,7 +30,7 @@ suites = {
...
@@ -30,7 +30,7 @@ suites = {
TestFile
(
"test_chunked_prefill.py"
,
336
),
TestFile
(
"test_chunked_prefill.py"
,
336
),
TestFile
(
"test_eagle_infer.py"
,
500
),
TestFile
(
"test_eagle_infer.py"
,
500
),
TestFile
(
"test_ebnf_constrained.py"
),
TestFile
(
"test_ebnf_constrained.py"
),
TestFile
(
"test_fa3.py"
,
5
00
),
TestFile
(
"test_fa3.py"
,
4
00
),
TestFile
(
"test_fp8_kernel.py"
,
8
),
TestFile
(
"test_fp8_kernel.py"
,
8
),
TestFile
(
"test_embedding_openai_server.py"
,
36
),
TestFile
(
"test_embedding_openai_server.py"
,
36
),
TestFile
(
"test_hidden_states.py"
,
55
),
TestFile
(
"test_hidden_states.py"
,
55
),
...
@@ -92,7 +92,7 @@ suites = {
...
@@ -92,7 +92,7 @@ suites = {
TestFile
(
"test_verl_engine.py"
,
100
),
TestFile
(
"test_verl_engine.py"
,
100
),
],
],
"per-commit-8-gpu"
:
[
"per-commit-8-gpu"
:
[
TestFile
(
"test_local_attn.py"
,
10
0
),
TestFile
(
"test_local_attn.py"
,
25
0
),
],
],
"nightly"
:
[
"nightly"
:
[
TestFile
(
"test_nightly_gsm8k_eval.py"
),
TestFile
(
"test_nightly_gsm8k_eval.py"
),
...
...
test/srt/test_fa3.py
View file @
4d23ba08
...
@@ -3,7 +3,6 @@ import unittest
...
@@ -3,7 +3,6 @@ import unittest
from
types
import
SimpleNamespace
from
types
import
SimpleNamespace
import
requests
import
requests
import
torch
from
sglang.srt.utils
import
get_device_sm
,
kill_process_tree
from
sglang.srt.utils
import
get_device_sm
,
kill_process_tree
from
sglang.test.few_shot_gsm8k
import
run_eval
as
run_eval_few_shot_gsm8k
from
sglang.test.few_shot_gsm8k
import
run_eval
as
run_eval_few_shot_gsm8k
...
@@ -14,6 +13,7 @@ from sglang.test.test_utils import (
...
@@ -14,6 +13,7 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN
,
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
DEFAULT_URL_FOR_TEST
,
CustomTestCase
,
popen_launch_server
,
popen_launch_server
,
)
)
...
@@ -47,9 +47,8 @@ if OFFLINE_MODE:
...
@@ -47,9 +47,8 @@ if OFFLINE_MODE:
# Default server arguments shared across all tests
# Default server arguments shared across all tests
DEFAULT_SERVER_ARGS
=
[
DEFAULT_SERVER_ARGS
=
[
"--trust-remote-code"
,
"--trust-remote-code"
,
"--enable-torch-compile"
,
"--cuda-graph-max-bs"
,
"--cuda-graph-max-bs"
,
"
2
"
,
"
4
"
,
"--attention-backend"
,
"--attention-backend"
,
"fa3"
,
"fa3"
,
]
]
...
@@ -60,7 +59,7 @@ Integration test for python/sglang/srt/layers/attention/flashattention_backend.p
...
@@ -60,7 +59,7 @@ Integration test for python/sglang/srt/layers/attention/flashattention_backend.p
@
unittest
.
skipIf
(
get_device_sm
()
<
90
,
"Test requires CUDA SM 90 or higher"
)
@
unittest
.
skipIf
(
get_device_sm
()
<
90
,
"Test requires CUDA SM 90 or higher"
)
class
BaseFlashAttentionTest
(
unittest
.
TestCase
):
class
BaseFlashAttentionTest
(
Custom
TestCase
):
"""Base class for testing FlashAttention3."""
"""Base class for testing FlashAttention3."""
model
=
DEFAULT_MODEL_NAME_FOR_TEST
model
=
DEFAULT_MODEL_NAME_FOR_TEST
...
@@ -78,13 +77,13 @@ class BaseFlashAttentionTest(unittest.TestCase):
...
@@ -78,13 +77,13 @@ class BaseFlashAttentionTest(unittest.TestCase):
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
# disable deep gemm precompile to make launch server faster
# disable deep gemm precompile to make launch server faster
# please don't do this if you want to make your inference workload faster
# please don't do this if you want to make your inference workload faster
os
.
environ
[
"SGL_JIT_DEEPGEMM_PRECOMPILE"
]
=
"False"
os
.
environ
[
"SGL_JIT_DEEPGEMM_PRECOMPILE"
]
=
"false"
os
.
environ
[
"SGL_ENABLE_JIT_DEEPGEMM"
]
=
"false"
cls
.
process
=
popen_launch_server
(
cls
.
process
=
popen_launch_server
(
cls
.
model
,
cls
.
model
,
cls
.
base_url
,
cls
.
base_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
other_args
=
cls
.
get_server_args
(),
other_args
=
cls
.
get_server_args
(),
env
=
os
.
environ
,
)
)
@
classmethod
@
classmethod
...
@@ -92,6 +91,8 @@ class BaseFlashAttentionTest(unittest.TestCase):
...
@@ -92,6 +91,8 @@ class BaseFlashAttentionTest(unittest.TestCase):
kill_process_tree
(
cls
.
process
.
pid
)
kill_process_tree
(
cls
.
process
.
pid
)
def
test_gsm8k
(
self
):
def
test_gsm8k
(
self
):
requests
.
get
(
self
.
base_url
+
"/flush_cache"
)
args
=
SimpleNamespace
(
args
=
SimpleNamespace
(
num_shots
=
4
,
num_shots
=
4
,
num_questions
=
100
,
num_questions
=
100
,
...
@@ -102,7 +103,7 @@ class BaseFlashAttentionTest(unittest.TestCase):
...
@@ -102,7 +103,7 @@ class BaseFlashAttentionTest(unittest.TestCase):
data_path
=
GSM_DATASET_PATH
,
data_path
=
GSM_DATASET_PATH
,
)
)
metrics
=
run_eval_few_shot_gsm8k
(
args
)
metrics
=
run_eval_few_shot_gsm8k
(
args
)
print
(
metrics
)
print
(
f
"
{
metrics
=
}
"
)
# Use the appropriate metric key based on the test class
# Use the appropriate metric key based on the test class
metric_key
=
"accuracy"
metric_key
=
"accuracy"
...
@@ -192,60 +193,6 @@ class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
...
@@ -192,60 +193,6 @@ class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
return
args
return
args
class
TestFlashAttention3SpeculativeDecodeTopk
(
BaseFlashAttentionTest
):
"""Test FlashAttention3 with speculative decode enabled, topk > 1"""
model
=
DEFAULT_MODEL_NAME_FOR_TEST
@
classmethod
def
get_server_args
(
cls
):
args
=
super
().
get_server_args
()
args
.
extend
(
[
"--cuda-graph-max-bs"
,
"2"
,
"--speculative-algorithm"
,
"EAGLE3"
,
"--speculative-draft"
,
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3
,
"--speculative-num-steps"
,
"5"
,
"--speculative-eagle-topk"
,
"4"
,
"--speculative-num-draft-tokens"
,
"8"
,
"--dtype"
,
"float16"
,
]
)
return
args
def
test_gsm8k
(
self
):
"""
Override the test_gsm8k to further test for average speculative accept length.
"""
requests
.
get
(
self
.
base_url
+
"/flush_cache"
)
args
=
SimpleNamespace
(
num_shots
=
5
,
data_path
=
GSM_DATASET_PATH
,
num_questions
=
200
,
max_new_tokens
=
512
,
parallel
=
128
,
host
=
"http://127.0.0.1"
,
port
=
int
(
self
.
base_url
.
split
(
":"
)[
-
1
]),
)
metrics
=
run_eval_few_shot_gsm8k
(
args
)
print
(
metrics
)
self
.
assertGreater
(
metrics
[
"accuracy"
],
0.60
)
server_info
=
requests
.
get
(
self
.
base_url
+
"/get_server_info"
)
avg_spec_accept_length
=
server_info
.
json
()[
"avg_spec_accept_length"
]
print
(
f
"
{
avg_spec_accept_length
=
}
"
)
self
.
assertGreater
(
avg_spec_accept_length
,
1.8
)
class
TestFlashAttention3MLASpeculativeDecode
(
BaseFlashAttentionTest
):
class
TestFlashAttention3MLASpeculativeDecode
(
BaseFlashAttentionTest
):
"""Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model"""
"""Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model"""
...
...
test/srt/test_local_attn.py
View file @
4d23ba08
...
@@ -10,12 +10,13 @@ from sglang.test.test_utils import (
...
@@ -10,12 +10,13 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION
,
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
DEFAULT_URL_FOR_TEST
,
CustomTestCase
,
popen_launch_server
,
popen_launch_server
,
)
)
@
unittest
.
skipIf
(
get_device_sm
()
<
90
,
"Test requires CUDA SM 90 or higher"
)
@
unittest
.
skipIf
(
get_device_sm
()
<
90
,
"Test requires CUDA SM 90 or higher"
)
class
TestFlashAttention3LocalAttn
(
unittest
.
TestCase
):
class
TestFlashAttention3LocalAttn
(
Custom
TestCase
):
model
=
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION
model
=
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION
base_url
=
DEFAULT_URL_FOR_TEST
base_url
=
DEFAULT_URL_FOR_TEST
accuracy_threshold
=
0.90
accuracy_threshold
=
0.90
...
@@ -23,7 +24,6 @@ class TestFlashAttention3LocalAttn(unittest.TestCase):
...
@@ -23,7 +24,6 @@ class TestFlashAttention3LocalAttn(unittest.TestCase):
@
classmethod
@
classmethod
def
get_server_args
(
cls
):
def
get_server_args
(
cls
):
return
[
return
[
"--trust-remote-code"
,
"--cuda-graph-max-bs"
,
"--cuda-graph-max-bs"
,
"2"
,
"2"
,
"--attention-backend"
,
"--attention-backend"
,
...
@@ -36,8 +36,6 @@ class TestFlashAttention3LocalAttn(unittest.TestCase):
...
@@ -36,8 +36,6 @@ class TestFlashAttention3LocalAttn(unittest.TestCase):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
# disable deep gemm precompile to make launch server faster
# please don't do this if you want to make your inference workload faster
cls
.
process
=
popen_launch_server
(
cls
.
process
=
popen_launch_server
(
cls
.
model
,
cls
.
model
,
cls
.
base_url
,
cls
.
base_url
,
...
@@ -51,6 +49,8 @@ class TestFlashAttention3LocalAttn(unittest.TestCase):
...
@@ -51,6 +49,8 @@ class TestFlashAttention3LocalAttn(unittest.TestCase):
kill_process_tree
(
cls
.
process
.
pid
)
kill_process_tree
(
cls
.
process
.
pid
)
def
test_gsm8k
(
self
):
def
test_gsm8k
(
self
):
requests
.
get
(
self
.
base_url
+
"/flush_cache"
)
args
=
SimpleNamespace
(
args
=
SimpleNamespace
(
num_shots
=
4
,
num_shots
=
4
,
num_questions
=
100
,
num_questions
=
100
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment