Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
a45a4b23
Unverified
Commit
a45a4b23
authored
Apr 27, 2025
by
Baizhou Zhang
Committed by
GitHub
Apr 27, 2025
Browse files
Split local attention test from fa3 test (#5774)
parent
981a2619
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
74 additions
and
18 deletions
+74
-18
test/srt/run_suite.py
test/srt/run_suite.py
+2
-1
test/srt/test_fa3.py
test/srt/test_fa3.py
+0
-17
test/srt/test_local_attn.py
test/srt/test_local_attn.py
+72
-0
No files found.
test/srt/run_suite.py
View file @
a45a4b23
...
@@ -30,6 +30,7 @@ suites = {
...
@@ -30,6 +30,7 @@ suites = {
TestFile
(
"test_chunked_prefill.py"
,
336
),
TestFile
(
"test_chunked_prefill.py"
,
336
),
TestFile
(
"test_eagle_infer.py"
,
500
),
TestFile
(
"test_eagle_infer.py"
,
500
),
TestFile
(
"test_ebnf_constrained.py"
),
TestFile
(
"test_ebnf_constrained.py"
),
TestFile
(
"test_fa3.py"
,
500
),
TestFile
(
"test_fp8_kernel.py"
,
8
),
TestFile
(
"test_fp8_kernel.py"
,
8
),
TestFile
(
"test_embedding_openai_server.py"
,
36
),
TestFile
(
"test_embedding_openai_server.py"
,
36
),
TestFile
(
"test_hidden_states.py"
,
55
),
TestFile
(
"test_hidden_states.py"
,
55
),
...
@@ -91,7 +92,7 @@ suites = {
...
@@ -91,7 +92,7 @@ suites = {
TestFile
(
"test_verl_engine.py"
,
100
),
TestFile
(
"test_verl_engine.py"
,
100
),
],
],
"per-commit-8-gpu"
:
[
"per-commit-8-gpu"
:
[
TestFile
(
"test_
fa3
.py"
,
3
0
),
TestFile
(
"test_
local_attn
.py"
,
10
0
),
],
],
"nightly"
:
[
"nightly"
:
[
TestFile
(
"test_nightly_gsm8k_eval.py"
),
TestFile
(
"test_nightly_gsm8k_eval.py"
),
...
...
test/srt/test_fa3.py
View file @
a45a4b23
...
@@ -10,7 +10,6 @@ from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
...
@@ -10,7 +10,6 @@ from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from
sglang.test.test_utils
import
(
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3
,
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3
,
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION
,
DEFAULT_MODEL_NAME_FOR_TEST_MLA
,
DEFAULT_MODEL_NAME_FOR_TEST_MLA
,
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN
,
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
...
@@ -127,22 +126,6 @@ class TestFlashAttention3MLA(BaseFlashAttentionTest):
...
@@ -127,22 +126,6 @@ class TestFlashAttention3MLA(BaseFlashAttentionTest):
return
DEFAULT_SERVER_ARGS
return
DEFAULT_SERVER_ARGS
class
TestFlashAttention3LocalAttn
(
BaseFlashAttentionTest
):
"""Test FlashAttention3 with Model with local attention, e.g. Llama 4."""
accuracy_threshold
=
0.70
model
=
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION
@
classmethod
def
get_server_args
(
cls
):
cloned_args
=
DEFAULT_SERVER_ARGS
.
copy
()
# remove --enable-torch-compile from cloned_args since llama4 does not support it for now
cloned_args
.
remove
(
"--enable-torch-compile"
)
# we cannot use scout's 10m context due to this bug: https://github.com/sgl-project/sglang/issues/5755
cloned_args
.
extend
([
"--tp"
,
"4"
,
"--context-length"
,
"1000000"
])
return
cloned_args
class
TestFlashAttention3SpeculativeDecode
(
BaseFlashAttentionTest
):
class
TestFlashAttention3SpeculativeDecode
(
BaseFlashAttentionTest
):
"""Test FlashAttention3 with speculative decode enabled with Llama 3.1 8B and its eagle3 model"""
"""Test FlashAttention3 with speculative decode enabled with Llama 3.1 8B and its eagle3 model"""
...
...
test/srt/test_local_attn.py
0 → 100644
View file @
a45a4b23
import
os
import
unittest
from
types
import
SimpleNamespace
import
requests
from
sglang.srt.utils
import
get_device_sm
,
kill_process_tree
from
sglang.test.few_shot_gsm8k
import
run_eval
as
run_eval_few_shot_gsm8k
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
popen_launch_server
,
)
@unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
class TestFlashAttention3LocalAttn(unittest.TestCase):
    """Test FlashAttention3 with a model that uses local attention, e.g. Llama 4.

    A server is launched once for the whole class with the ``fa3`` attention
    backend, and a few-shot GSM8K evaluation is run against it. The class is
    skipped when ``get_device_sm()`` reports a value below 90.
    """

    model = DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION
    base_url = DEFAULT_URL_FOR_TEST
    # test_gsm8k requires the measured accuracy to be strictly greater than this.
    accuracy_threshold = 0.90

    @classmethod
    def get_server_args(cls):
        """Return the extra command-line arguments passed to the server launch."""
        return [
            "--trust-remote-code",
            "--cuda-graph-max-bs",
            "2",
            "--attention-backend",
            "fa3",
            "--tp",
            "4",
            # we cannot use scout's 10m context due to this bug:
            # https://github.com/sgl-project/sglang/issues/5755
            "--context-length",
            "1000000",
        ]

    @classmethod
    def setUpClass(cls):
        """Launch the server process shared by every test in this class."""
        # The server inherits the current process environment as-is; no
        # environment overrides (e.g. to skip DeepGEMM precompile) are
        # applied here.
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=cls.get_server_args(),
            env=os.environ,
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the launched server process tree."""
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Run few-shot GSM8K and require accuracy above ``accuracy_threshold``."""
        # The port is taken from the text after the last ":" in base_url.
        port = int(self.base_url.split(":")[-1])
        args = SimpleNamespace(
            num_shots=4,
            num_questions=100,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=port,
            data_path=None,
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"{metrics=}")

        self.assertGreater(metrics["accuracy"], self.accuracy_threshold)
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment