Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
4d23ba08
"sgl-kernel/vscode:/vscode.git/clone" did not exist on "988ab646ec4cccb86141a075510ca71b473671b0"
Unverified
Commit
4d23ba08
authored
Apr 27, 2025
by
Lianmin Zheng
Committed by
GitHub
Apr 27, 2025
Browse files
Simplify FA3 tests (#5779)
parent
6e313c1b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
14 additions
and
67 deletions
+14
-67
test/srt/run_suite.py
test/srt/run_suite.py
+2
-2
test/srt/test_fa3.py
test/srt/test_fa3.py
+8
-61
test/srt/test_local_attn.py
test/srt/test_local_attn.py
+4
-4
No files found.
test/srt/run_suite.py
View file @
4d23ba08
...
...
@@ -30,7 +30,7 @@ suites = {
TestFile
(
"test_chunked_prefill.py"
,
336
),
TestFile
(
"test_eagle_infer.py"
,
500
),
TestFile
(
"test_ebnf_constrained.py"
),
TestFile
(
"test_fa3.py"
,
5
00
),
TestFile
(
"test_fa3.py"
,
4
00
),
TestFile
(
"test_fp8_kernel.py"
,
8
),
TestFile
(
"test_embedding_openai_server.py"
,
36
),
TestFile
(
"test_hidden_states.py"
,
55
),
...
...
@@ -92,7 +92,7 @@ suites = {
TestFile
(
"test_verl_engine.py"
,
100
),
],
"per-commit-8-gpu"
:
[
TestFile
(
"test_local_attn.py"
,
10
0
),
TestFile
(
"test_local_attn.py"
,
25
0
),
],
"nightly"
:
[
TestFile
(
"test_nightly_gsm8k_eval.py"
),
...
...
test/srt/test_fa3.py
View file @
4d23ba08
...
...
@@ -3,7 +3,6 @@ import unittest
from
types
import
SimpleNamespace
import
requests
import
torch
from
sglang.srt.utils
import
get_device_sm
,
kill_process_tree
from
sglang.test.few_shot_gsm8k
import
run_eval
as
run_eval_few_shot_gsm8k
...
...
@@ -14,6 +13,7 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
CustomTestCase
,
popen_launch_server
,
)
...
...
@@ -47,9 +47,8 @@ if OFFLINE_MODE:
# Default server arguments shared across all tests
DEFAULT_SERVER_ARGS
=
[
"--trust-remote-code"
,
"--enable-torch-compile"
,
"--cuda-graph-max-bs"
,
"
2
"
,
"
4
"
,
"--attention-backend"
,
"fa3"
,
]
...
...
@@ -60,7 +59,7 @@ Integration test for python/sglang/srt/layers/attention/flashattention_backend.p
@
unittest
.
skipIf
(
get_device_sm
()
<
90
,
"Test requires CUDA SM 90 or higher"
)
class
BaseFlashAttentionTest
(
unittest
.
TestCase
):
class
BaseFlashAttentionTest
(
Custom
TestCase
):
"""Base class for testing FlashAttention3."""
model
=
DEFAULT_MODEL_NAME_FOR_TEST
...
...
@@ -78,13 +77,13 @@ class BaseFlashAttentionTest(unittest.TestCase):
def
setUpClass
(
cls
):
# disable deep gemm precompile to make launch server faster
# please don't do this if you want to make your inference workload faster
os
.
environ
[
"SGL_JIT_DEEPGEMM_PRECOMPILE"
]
=
"False"
os
.
environ
[
"SGL_JIT_DEEPGEMM_PRECOMPILE"
]
=
"false"
os
.
environ
[
"SGL_ENABLE_JIT_DEEPGEMM"
]
=
"false"
cls
.
process
=
popen_launch_server
(
cls
.
model
,
cls
.
base_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
other_args
=
cls
.
get_server_args
(),
env
=
os
.
environ
,
)
@
classmethod
...
...
@@ -92,6 +91,8 @@ class BaseFlashAttentionTest(unittest.TestCase):
kill_process_tree
(
cls
.
process
.
pid
)
def
test_gsm8k
(
self
):
requests
.
get
(
self
.
base_url
+
"/flush_cache"
)
args
=
SimpleNamespace
(
num_shots
=
4
,
num_questions
=
100
,
...
...
@@ -102,7 +103,7 @@ class BaseFlashAttentionTest(unittest.TestCase):
data_path
=
GSM_DATASET_PATH
,
)
metrics
=
run_eval_few_shot_gsm8k
(
args
)
print
(
metrics
)
print
(
f
"
{
metrics
=
}
"
)
# Use the appropriate metric key based on the test class
metric_key
=
"accuracy"
...
...
@@ -192,60 +193,6 @@ class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
return
args
class
TestFlashAttention3SpeculativeDecodeTopk
(
BaseFlashAttentionTest
):
"""Test FlashAttention3 with speculative decode enabled, topk > 1"""
model
=
DEFAULT_MODEL_NAME_FOR_TEST
@
classmethod
def
get_server_args
(
cls
):
args
=
super
().
get_server_args
()
args
.
extend
(
[
"--cuda-graph-max-bs"
,
"2"
,
"--speculative-algorithm"
,
"EAGLE3"
,
"--speculative-draft"
,
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3
,
"--speculative-num-steps"
,
"5"
,
"--speculative-eagle-topk"
,
"4"
,
"--speculative-num-draft-tokens"
,
"8"
,
"--dtype"
,
"float16"
,
]
)
return
args
def
test_gsm8k
(
self
):
"""
Override the test_gsm8k to further test for average speculative accept length.
"""
requests
.
get
(
self
.
base_url
+
"/flush_cache"
)
args
=
SimpleNamespace
(
num_shots
=
5
,
data_path
=
GSM_DATASET_PATH
,
num_questions
=
200
,
max_new_tokens
=
512
,
parallel
=
128
,
host
=
"http://127.0.0.1"
,
port
=
int
(
self
.
base_url
.
split
(
":"
)[
-
1
]),
)
metrics
=
run_eval_few_shot_gsm8k
(
args
)
print
(
metrics
)
self
.
assertGreater
(
metrics
[
"accuracy"
],
0.60
)
server_info
=
requests
.
get
(
self
.
base_url
+
"/get_server_info"
)
avg_spec_accept_length
=
server_info
.
json
()[
"avg_spec_accept_length"
]
print
(
f
"
{
avg_spec_accept_length
=
}
"
)
self
.
assertGreater
(
avg_spec_accept_length
,
1.8
)
class
TestFlashAttention3MLASpeculativeDecode
(
BaseFlashAttentionTest
):
"""Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model"""
...
...
test/srt/test_local_attn.py
View file @
4d23ba08
...
...
@@ -10,12 +10,13 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
CustomTestCase
,
popen_launch_server
,
)
@
unittest
.
skipIf
(
get_device_sm
()
<
90
,
"Test requires CUDA SM 90 or higher"
)
class
TestFlashAttention3LocalAttn
(
unittest
.
TestCase
):
class
TestFlashAttention3LocalAttn
(
Custom
TestCase
):
model
=
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION
base_url
=
DEFAULT_URL_FOR_TEST
accuracy_threshold
=
0.90
...
...
@@ -23,7 +24,6 @@ class TestFlashAttention3LocalAttn(unittest.TestCase):
@
classmethod
def
get_server_args
(
cls
):
return
[
"--trust-remote-code"
,
"--cuda-graph-max-bs"
,
"2"
,
"--attention-backend"
,
...
...
@@ -36,8 +36,6 @@ class TestFlashAttention3LocalAttn(unittest.TestCase):
@
classmethod
def
setUpClass
(
cls
):
# disable deep gemm precompile to make launch server faster
# please don't do this if you want to make your inference workload faster
cls
.
process
=
popen_launch_server
(
cls
.
model
,
cls
.
base_url
,
...
...
@@ -51,6 +49,8 @@ class TestFlashAttention3LocalAttn(unittest.TestCase):
kill_process_tree
(
cls
.
process
.
pid
)
def
test_gsm8k
(
self
):
requests
.
get
(
self
.
base_url
+
"/flush_cache"
)
args
=
SimpleNamespace
(
num_shots
=
4
,
num_questions
=
100
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment