Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
00974e4f
Unverified
Commit
00974e4f
authored
Sep 06, 2025
by
Shangming Cai
Committed by
GitHub
Sep 06, 2025
Browse files
[CI] Refactor disaggregation tests (#10068)
Signed-off-by:
Shangming Cai
<
csmthu@gmail.com
>
parent
5f1eb204
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
100 additions
and
353 deletions
+100
-353
python/sglang/test/test_disaggregation_utils.py
python/sglang/test/test_disaggregation_utils.py
+66
-0
test/srt/run_suite.py
test/srt/run_suite.py
+1
-0
test/srt/test_disaggregation.py
test/srt/test_disaggregation.py
+17
-196
test/srt/test_disaggregation_different_tp.py
test/srt/test_disaggregation_different_tp.py
+5
-112
test/srt/test_disaggregation_pp.py
test/srt/test_disaggregation_pp.py
+11
-45
No files found.
python/sglang/test/test_disaggregation_utils.py
0 → 100644
View file @
00974e4f
import
time
import
requests
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.test_utils
import
(
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
CustomTestCase
,
popen_with_error_check
,
)
class TestDisaggregationBase(CustomTestCase):
    """Shared fixture logic for prefill/decode disaggregation tests.

    Subclasses are expected to set ``prefill_url``, ``decode_url``,
    ``base_host``, ``lb_port`` and ``lb_url`` on the class and to assign
    ``process_prefill`` / ``process_decode`` themselves; this base class
    does not define those values. It only launches the load balancer,
    polls health endpoints, and tears the processes down.
    """

    # Class-level defaults so tearDownClass is safe even when a subclass
    # setUpClass fails (or skips super().setUpClass()) before assignment.
    process_lb = None
    process_decode = None
    process_prefill = None

    # Upper bound (seconds) for a single health probe. Without it,
    # requests.get() can block indefinitely on a server that accepts the
    # connection but never responds, defeating the overall deadline.
    health_request_timeout = 10

    @classmethod
    def setUpClass(cls):
        """Reset the process handles for this test class to ``None``."""
        cls.process_lb, cls.process_decode, cls.process_prefill = None, None, None

    @classmethod
    def launch_lb(cls):
        """Start the load balancer and block until its /health returns 200.

        Reads ``prefill_url``, ``decode_url``, ``base_host``, ``lb_port``
        and ``lb_url`` from the class and stores the spawned process in
        ``cls.process_lb``.

        Raises:
            RuntimeError: propagated from ``wait_server_ready`` when the
                load balancer does not become healthy in time.
        """
        lb_command = [
            "python3",
            "-m",
            "sglang_router.launch_router",
            "--pd-disaggregation",
            "--mini-lb",  # FIXME: remove this
            "--prefill",
            cls.prefill_url,
            "--decode",
            cls.decode_url,
            "--host",
            cls.base_host,
            "--port",
            # str() keeps the join below (and the argv list) valid even if
            # a subclass stores the port as an int.
            str(cls.lb_port),
        ]
        print("Starting load balancer:", " ".join(lb_command))
        cls.process_lb = popen_with_error_check(lb_command)
        cls.wait_server_ready(cls.lb_url + "/health")

    @classmethod
    def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
        """Poll ``url`` once per second until it answers with HTTP 200.

        Args:
            url: Full URL to probe (callers pass a ``/health`` endpoint).
            timeout: Overall deadline in seconds for the server to come up.

        Raises:
            RuntimeError: if no 200 response is seen within ``timeout``
                seconds; the last probe error, if any, is chained as the
                cause to aid debugging.
        """
        start_time = time.perf_counter()
        last_error = None
        while True:
            try:
                response = requests.get(url, timeout=cls.health_request_timeout)
                if response.status_code == 200:
                    print(f"Server {url} is ready")
                    return
            except Exception as e:
                # Best-effort polling: the server is typically just not
                # listening yet. Remember the error for the final report.
                last_error = e

            if time.perf_counter() - start_time > timeout:
                raise RuntimeError(
                    f"Server {url} failed to start in {timeout}s"
                ) from last_error
            time.sleep(1)

    @classmethod
    def tearDownClass(cls):
        """Kill the load balancer, decode and prefill processes, if any.

        Errors raised while killing a process are printed and ignored so
        that the remaining processes are still cleaned up.
        """
        processes = [
            getattr(cls, "process_lb", None),
            getattr(cls, "process_decode", None),
            getattr(cls, "process_prefill", None),
        ]
        for process in processes:
            if process:
                try:
                    kill_process_tree(process.pid)
                except Exception as e:
                    print(f"Error killing process {process.pid}: {e}")

        # wait for 5 seconds
        time.sleep(5)
test/srt/run_suite.py
View file @
00974e4f
...
@@ -139,6 +139,7 @@ suites = {
...
@@ -139,6 +139,7 @@ suites = {
TestFile
(
"lora/test_lora_llama4.py"
,
600
),
TestFile
(
"lora/test_lora_llama4.py"
,
600
),
TestFile
(
"test_disaggregation.py"
,
499
),
TestFile
(
"test_disaggregation.py"
,
499
),
TestFile
(
"test_disaggregation_different_tp.py"
,
155
),
TestFile
(
"test_disaggregation_different_tp.py"
,
155
),
TestFile
(
"test_disaggregation_pp.py"
,
60
),
TestFile
(
"test_full_deepseek_v3.py"
,
333
),
TestFile
(
"test_full_deepseek_v3.py"
,
333
),
],
],
"per-commit-8-gpu-b200"
:
[
"per-commit-8-gpu-b200"
:
[
...
...
test/srt/test_disaggregation.py
View file @
00974e4f
...
@@ -7,21 +7,19 @@ from urllib.parse import urlparse
...
@@ -7,21 +7,19 @@ from urllib.parse import urlparse
import
requests
import
requests
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.few_shot_gsm8k
import
run_eval
as
run_eval_few_shot_gsm8k
from
sglang.test.few_shot_gsm8k
import
run_eval
as
run_eval_few_shot_gsm8k
from
sglang.test.test_disaggregation_utils
import
TestDisaggregationBase
from
sglang.test.test_utils
import
(
from
sglang.test.test_utils
import
(
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
,
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
,
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
DEFAULT_URL_FOR_TEST
,
CustomTestCase
,
popen_launch_pd_server
,
popen_launch_pd_server
,
popen_with_error_check
,
)
)
class
TestDisaggregationAccuracy
(
CustomTestC
ase
):
class
TestDisaggregationAccuracy
(
TestDisaggregationB
ase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
...
@@ -44,25 +42,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
...
@@ -44,25 +42,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
lb_command
=
[
cls
.
launch_lb
()
"python3"
,
"-m"
,
"sglang_router.launch_router"
,
"--pd-disaggregation"
,
"--mini-lb"
,
# FIXME: remove this
"--prefill"
,
cls
.
prefill_url
,
"--decode"
,
cls
.
decode_url
,
"--host"
,
cls
.
base_host
,
"--port"
,
cls
.
lb_port
,
]
print
(
"Starting load balancer:"
,
" "
.
join
(
lb_command
))
cls
.
process_lb
=
popen_with_error_check
(
lb_command
)
cls
.
wait_server_ready
(
cls
.
lb_url
+
"/health"
)
@
classmethod
@
classmethod
def
start_prefill
(
cls
):
def
start_prefill
(
cls
):
...
@@ -102,34 +82,6 @@ class TestDisaggregationAccuracy(CustomTestCase):
...
@@ -102,34 +82,6 @@ class TestDisaggregationAccuracy(CustomTestCase):
other_args
=
decode_args
,
other_args
=
decode_args
,
)
)
@
classmethod
def
wait_server_ready
(
cls
,
url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
):
start_time
=
time
.
perf_counter
()
while
True
:
try
:
response
=
requests
.
get
(
url
)
if
response
.
status_code
==
200
:
print
(
f
"Server
{
url
}
is ready"
)
return
except
Exception
:
pass
if
time
.
perf_counter
()
-
start_time
>
timeout
:
raise
RuntimeError
(
f
"Server
{
url
}
failed to start in
{
timeout
}
s"
)
time
.
sleep
(
1
)
@
classmethod
def
tearDownClass
(
cls
):
for
process
in
[
cls
.
process_lb
,
cls
.
process_decode
,
cls
.
process_prefill
]:
if
process
:
try
:
kill_process_tree
(
process
.
pid
)
except
Exception
as
e
:
print
(
f
"Error killing process
{
process
.
pid
}
:
{
e
}
"
)
# wait for 5 seconds
time
.
sleep
(
5
)
def
test_gsm8k
(
self
):
def
test_gsm8k
(
self
):
args
=
SimpleNamespace
(
args
=
SimpleNamespace
(
num_shots
=
5
,
num_shots
=
5
,
...
@@ -199,7 +151,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
...
@@ -199,7 +151,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
json
.
loads
(
output
)
json
.
loads
(
output
)
class
TestDisaggregationMooncakeFailure
(
CustomTestC
ase
):
class
TestDisaggregationMooncakeFailure
(
TestDisaggregationB
ase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
# set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure
# set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure
...
@@ -225,25 +177,12 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
...
@@ -225,25 +177,12 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
lb_command
=
[
cls
.
launch_lb
()
"python3"
,
"-m"
,
"sglang_router.launch_router"
,
"--pd-disaggregation"
,
"--mini-lb"
,
# FIXME: remove this
"--prefill"
,
cls
.
prefill_url
,
"--decode"
,
cls
.
decode_url
,
"--host"
,
cls
.
base_host
,
"--port"
,
cls
.
lb_port
,
]
print
(
"Starting load balancer:"
,
" "
.
join
(
lb_command
))
@
classmethod
cls
.
process_lb
=
popen_with_error_check
(
lb_command
)
def
tearDownClass
(
cls
):
cls
.
wait_server_ready
(
cls
.
lb_url
+
"/health"
)
os
.
environ
.
pop
(
"DISAGGREGATION_TEST_FAILURE_PROB"
)
super
().
tearDownClass
()
@
classmethod
@
classmethod
def
start_prefill
(
cls
):
def
start_prefill
(
cls
):
...
@@ -283,36 +222,6 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
...
@@ -283,36 +222,6 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
other_args
=
decode_args
,
other_args
=
decode_args
,
)
)
@
classmethod
def
wait_server_ready
(
cls
,
url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
):
start_time
=
time
.
perf_counter
()
while
True
:
try
:
response
=
requests
.
get
(
url
)
if
response
.
status_code
==
200
:
print
(
f
"Server
{
url
}
is ready"
)
return
except
Exception
:
pass
if
time
.
perf_counter
()
-
start_time
>
timeout
:
raise
RuntimeError
(
f
"Server
{
url
}
failed to start in
{
timeout
}
s"
)
time
.
sleep
(
1
)
@
classmethod
def
tearDownClass
(
cls
):
# unset DISAGGREGATION_TEST_FAILURE_PROB
os
.
environ
.
pop
(
"DISAGGREGATION_TEST_FAILURE_PROB"
)
for
process
in
[
cls
.
process_lb
,
cls
.
process_decode
,
cls
.
process_prefill
]:
if
process
:
try
:
kill_process_tree
(
process
.
pid
)
except
Exception
as
e
:
print
(
f
"Error killing process
{
process
.
pid
}
:
{
e
}
"
)
# wait for 5 seconds
time
.
sleep
(
5
)
def
test_gsm8k
(
self
):
def
test_gsm8k
(
self
):
args
=
SimpleNamespace
(
args
=
SimpleNamespace
(
num_shots
=
5
,
num_shots
=
5
,
...
@@ -341,7 +250,7 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
...
@@ -341,7 +250,7 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
raise
e
from
health_check_error
raise
e
from
health_check_error
class
TestDisaggregationMooncakeSpec
(
CustomTestC
ase
):
class
TestDisaggregationMooncakeSpec
(
TestDisaggregationB
ase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
...
@@ -380,41 +289,7 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
...
@@ -380,41 +289,7 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
lb_command
=
[
cls
.
launch_lb
()
"python3"
,
"-m"
,
"sglang_router.launch_router"
,
"--pd-disaggregation"
,
"--mini-lb"
,
# FIXME: remove this
"--prefill"
,
cls
.
prefill_url
,
"--decode"
,
cls
.
decode_url
,
"--host"
,
cls
.
base_host
,
"--port"
,
cls
.
lb_port
,
]
print
(
"Starting load balancer:"
,
" "
.
join
(
lb_command
))
cls
.
process_lb
=
popen_with_error_check
(
lb_command
)
cls
.
wait_server_ready
(
cls
.
lb_url
+
"/health"
)
@
classmethod
def
wait_server_ready
(
cls
,
url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
):
start_time
=
time
.
perf_counter
()
while
True
:
try
:
response
=
requests
.
get
(
url
)
if
response
.
status_code
==
200
:
print
(
f
"Server
{
url
}
is ready"
)
return
except
Exception
:
pass
if
time
.
perf_counter
()
-
start_time
>
timeout
:
raise
RuntimeError
(
f
"Server
{
url
}
failed to start in
{
timeout
}
s"
)
time
.
sleep
(
1
)
@
classmethod
@
classmethod
def
start_prefill
(
cls
):
def
start_prefill
(
cls
):
...
@@ -454,18 +329,6 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
...
@@ -454,18 +329,6 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
other_args
=
decode_args
,
other_args
=
decode_args
,
)
)
@
classmethod
def
tearDownClass
(
cls
):
for
process
in
[
cls
.
process_lb
,
cls
.
process_decode
,
cls
.
process_prefill
]:
if
process
:
try
:
kill_process_tree
(
process
.
pid
)
except
Exception
as
e
:
print
(
f
"Error killing process
{
process
.
pid
}
:
{
e
}
"
)
# wait for 5 seconds
time
.
sleep
(
5
)
def
test_gsm8k
(
self
):
def
test_gsm8k
(
self
):
args
=
SimpleNamespace
(
args
=
SimpleNamespace
(
num_shots
=
5
,
num_shots
=
5
,
...
@@ -482,7 +345,7 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
...
@@ -482,7 +345,7 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
self
.
assertGreater
(
metrics
[
"accuracy"
],
0.20
)
self
.
assertGreater
(
metrics
[
"accuracy"
],
0.20
)
class
TestDisaggregationSimulatedRetract
(
CustomTestC
ase
):
class
TestDisaggregationSimulatedRetract
(
TestDisaggregationB
ase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
os
.
environ
[
"SGLANG_TEST_RETRACT"
]
=
"true"
os
.
environ
[
"SGLANG_TEST_RETRACT"
]
=
"true"
...
@@ -506,25 +369,12 @@ class TestDisaggregationSimulatedRetract(CustomTestCase):
...
@@ -506,25 +369,12 @@ class TestDisaggregationSimulatedRetract(CustomTestCase):
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
lb_command
=
[
cls
.
launch_lb
()
"python3"
,
"-m"
,
"sglang_router.launch_router"
,
"--pd-disaggregation"
,
"--mini-lb"
,
# FIXME: remove this
"--prefill"
,
cls
.
prefill_url
,
"--decode"
,
cls
.
decode_url
,
"--host"
,
cls
.
base_host
,
"--port"
,
cls
.
lb_port
,
]
print
(
"Starting load balancer:"
,
" "
.
join
(
lb_command
))
@
classmethod
cls
.
process_lb
=
popen_with_error_check
(
lb_command
)
def
tearDownClass
(
cls
):
cls
.
wait_server_ready
(
cls
.
lb_url
+
"/health"
)
os
.
environ
.
pop
(
"SGLANG_TEST_RETRACT"
)
super
().
tearDownClass
()
@
classmethod
@
classmethod
def
start_prefill
(
cls
):
def
start_prefill
(
cls
):
...
@@ -564,35 +414,6 @@ class TestDisaggregationSimulatedRetract(CustomTestCase):
...
@@ -564,35 +414,6 @@ class TestDisaggregationSimulatedRetract(CustomTestCase):
other_args
=
decode_args
,
other_args
=
decode_args
,
)
)
@
classmethod
def
wait_server_ready
(
cls
,
url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
):
start_time
=
time
.
perf_counter
()
while
True
:
try
:
response
=
requests
.
get
(
url
)
if
response
.
status_code
==
200
:
print
(
f
"Server
{
url
}
is ready"
)
return
except
Exception
:
pass
if
time
.
perf_counter
()
-
start_time
>
timeout
:
raise
RuntimeError
(
f
"Server
{
url
}
failed to start in
{
timeout
}
s"
)
time
.
sleep
(
1
)
@
classmethod
def
tearDownClass
(
cls
):
os
.
environ
.
pop
(
"SGLANG_TEST_RETRACT"
)
for
process
in
[
cls
.
process_lb
,
cls
.
process_decode
,
cls
.
process_prefill
]:
if
process
:
try
:
kill_process_tree
(
process
.
pid
)
except
Exception
as
e
:
print
(
f
"Error killing process
{
process
.
pid
}
:
{
e
}
"
)
# wait for 5 seconds
time
.
sleep
(
5
)
def
test_gsm8k
(
self
):
def
test_gsm8k
(
self
):
args
=
SimpleNamespace
(
args
=
SimpleNamespace
(
num_shots
=
5
,
num_shots
=
5
,
...
...
test/srt/test_disaggregation_different_tp.py
View file @
00974e4f
import
os
import
os
import
subprocess
import
time
import
time
import
unittest
import
unittest
from
types
import
SimpleNamespace
from
types
import
SimpleNamespace
from
urllib.parse
import
urlparse
from
urllib.parse
import
urlparse
import
requests
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.few_shot_gsm8k
import
run_eval
as
run_eval_few_shot_gsm8k
from
sglang.test.few_shot_gsm8k
import
run_eval
as
run_eval_few_shot_gsm8k
from
sglang.test.test_disaggregation_utils
import
TestDisaggregationBase
from
sglang.test.test_utils
import
(
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST_MLA
,
DEFAULT_MODEL_NAME_FOR_TEST_MLA
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
DEFAULT_URL_FOR_TEST
,
CustomTestCase
,
popen_launch_pd_server
,
popen_launch_pd_server
,
popen_with_error_check
,
)
)
class
TestDisaggregationMooncakePrefillLargerTP
(
CustomTestC
ase
):
class
TestDisaggregationMooncakePrefillLargerTP
(
TestDisaggregationB
ase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
# Temporarily disable JIT DeepGEMM
# Temporarily disable JIT DeepGEMM
...
@@ -46,25 +41,7 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
...
@@ -46,25 +41,7 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
lb_command
=
[
cls
.
launch_lb
()
"python3"
,
"-m"
,
"sglang_router.launch_router"
,
"--pd-disaggregation"
,
"--mini-lb"
,
# FIXME: remove this
"--prefill"
,
cls
.
prefill_url
,
"--decode"
,
cls
.
decode_url
,
"--host"
,
cls
.
base_host
,
"--port"
,
cls
.
lb_port
,
]
print
(
"Starting load balancer:"
,
" "
.
join
(
lb_command
))
cls
.
process_lb
=
popen_with_error_check
(
lb_command
)
cls
.
wait_server_ready
(
cls
.
lb_url
+
"/health"
)
@
classmethod
@
classmethod
def
start_prefill
(
cls
):
def
start_prefill
(
cls
):
...
@@ -104,39 +81,6 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
...
@@ -104,39 +81,6 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
other_args
=
decode_args
,
other_args
=
decode_args
,
)
)
@
classmethod
def
wait_server_ready
(
cls
,
url
,
timeout
=
60
):
start_time
=
time
.
perf_counter
()
while
True
:
try
:
response
=
requests
.
get
(
url
)
if
response
.
status_code
==
200
:
print
(
f
"Server
{
url
}
is ready"
)
return
except
Exception
:
pass
if
time
.
perf_counter
()
-
start_time
>
timeout
:
raise
RuntimeError
(
f
"Server
{
url
}
failed to start in
{
timeout
}
s"
)
time
.
sleep
(
1
)
@
classmethod
def
tearDownClass
(
cls
):
# Restore JIT DeepGEMM environment variable
if
cls
.
original_jit_deepgemm
is
not
None
:
os
.
environ
[
"SGL_ENABLE_JIT_DEEPGEMM"
]
=
cls
.
original_jit_deepgemm
else
:
os
.
environ
.
pop
(
"SGL_ENABLE_JIT_DEEPGEMM"
,
None
)
for
process
in
[
cls
.
process_lb
,
cls
.
process_decode
,
cls
.
process_prefill
]:
if
process
:
try
:
kill_process_tree
(
process
.
pid
)
except
Exception
as
e
:
print
(
f
"Error killing process
{
process
.
pid
}
:
{
e
}
"
)
# wait for 5 seconds
time
.
sleep
(
5
)
def
test_gsm8k
(
self
):
def
test_gsm8k
(
self
):
args
=
SimpleNamespace
(
args
=
SimpleNamespace
(
num_shots
=
5
,
num_shots
=
5
,
...
@@ -153,7 +97,7 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
...
@@ -153,7 +97,7 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
self
.
assertGreater
(
metrics
[
"accuracy"
],
0.60
)
self
.
assertGreater
(
metrics
[
"accuracy"
],
0.60
)
class
TestDisaggregationMooncakeDecodeLargerTP
(
CustomTestC
ase
):
class
TestDisaggregationMooncakeDecodeLargerTP
(
TestDisaggregationB
ase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
# Temporarily disable JIT DeepGEMM
# Temporarily disable JIT DeepGEMM
...
@@ -180,25 +124,7 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase):
...
@@ -180,25 +124,7 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase):
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
lb_command
=
[
cls
.
launch_lb
()
"python3"
,
"-m"
,
"sglang_router.launch_router"
,
"--pd-disaggregation"
,
"--mini-lb"
,
# FIXME: remove this
"--prefill"
,
cls
.
prefill_url
,
"--decode"
,
cls
.
decode_url
,
"--host"
,
cls
.
base_host
,
"--port"
,
cls
.
lb_port
,
]
print
(
"Starting load balancer:"
,
" "
.
join
(
lb_command
))
cls
.
process_lb
=
popen_with_error_check
(
lb_command
)
cls
.
wait_server_ready
(
cls
.
lb_url
+
"/health"
)
@
classmethod
@
classmethod
def
start_prefill
(
cls
):
def
start_prefill
(
cls
):
...
@@ -238,39 +164,6 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase):
...
@@ -238,39 +164,6 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase):
other_args
=
decode_args
,
other_args
=
decode_args
,
)
)
@
classmethod
def
wait_server_ready
(
cls
,
url
,
timeout
=
60
):
start_time
=
time
.
perf_counter
()
while
True
:
try
:
response
=
requests
.
get
(
url
)
if
response
.
status_code
==
200
:
print
(
f
"Server
{
url
}
is ready"
)
return
except
Exception
:
pass
if
time
.
perf_counter
()
-
start_time
>
timeout
:
raise
RuntimeError
(
f
"Server
{
url
}
failed to start in
{
timeout
}
s"
)
time
.
sleep
(
1
)
@
classmethod
def
tearDownClass
(
cls
):
# Restore JIT DeepGEMM environment variable
if
cls
.
original_jit_deepgemm
is
not
None
:
os
.
environ
[
"SGL_ENABLE_JIT_DEEPGEMM"
]
=
cls
.
original_jit_deepgemm
else
:
os
.
environ
.
pop
(
"SGL_ENABLE_JIT_DEEPGEMM"
,
None
)
for
process
in
[
cls
.
process_lb
,
cls
.
process_decode
,
cls
.
process_prefill
]:
if
process
:
try
:
kill_process_tree
(
process
.
pid
)
except
Exception
as
e
:
print
(
f
"Error killing process
{
process
.
pid
}
:
{
e
}
"
)
# wait for 5 seconds
time
.
sleep
(
5
)
def
test_gsm8k
(
self
):
def
test_gsm8k
(
self
):
args
=
SimpleNamespace
(
args
=
SimpleNamespace
(
num_shots
=
5
,
num_shots
=
5
,
...
...
test/srt/test_disaggregation_pp.py
View file @
00974e4f
import
json
import
os
import
random
import
time
import
time
import
unittest
import
unittest
from
concurrent.futures
import
ThreadPoolExecutor
from
types
import
SimpleNamespace
from
types
import
SimpleNamespace
from
typing
import
List
,
Optional
from
urllib.parse
import
urlparse
import
requests
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.few_shot_gsm8k
import
run_eval
from
sglang.test.few_shot_gsm8k
import
run_eval
from
sglang.test.
runners
import
DEFAULT_PROMPTS
from
sglang.test.
test_disaggregation_utils
import
TestDisaggregationBase
from
sglang.test.test_utils
import
(
from
sglang.test.test_utils
import
(
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
,
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
DEFAULT_URL_FOR_TEST
,
CustomTestCase
,
popen_launch_pd_server
,
popen_launch_server
,
)
)
class
Test
PDPPAccuracy
(
unittest
.
TestC
ase
):
class
Test
DisaggregationPPAccuracy
(
TestDisaggregationB
ase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
...
@@ -46,27 +36,7 @@ class TestPDPPAccuracy(unittest.TestCase):
...
@@ -46,27 +36,7 @@ class TestPDPPAccuracy(unittest.TestCase):
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
lb_command
=
[
cls
.
launch_lb
()
"python3"
,
"-m"
,
"sglang_router.launch_router"
,
"--pd-disaggregation"
,
"--mini-lb"
,
# FIXME: remove this
"--prefill"
,
cls
.
prefill_url
,
"--decode"
,
cls
.
decode_url
,
"--host"
,
cls
.
base_host
,
"--port"
,
cls
.
lb_port
,
]
print
(
"Starting load balancer:"
,
" "
.
join
(
lb_command
))
cls
.
process_lb
=
subprocess
.
Popen
(
lb_command
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
cls
.
wait_server_ready
(
cls
.
lb_url
+
"/health"
)
@
classmethod
@
classmethod
def
start_prefill
(
cls
):
def
start_prefill
(
cls
):
...
@@ -75,11 +45,11 @@ class TestPDPPAccuracy(unittest.TestCase):
...
@@ -75,11 +45,11 @@ class TestPDPPAccuracy(unittest.TestCase):
"--disaggregation-mode"
,
"--disaggregation-mode"
,
"prefill"
,
"prefill"
,
"--tp-size"
,
"--tp-size"
,
"
2
"
,
"
1
"
,
"--pp-size"
,
"--pp-size"
,
"2"
,
"2"
,
"--disaggregation-ib-device"
,
"--disaggregation-ib-device"
,
"mlx5_roce0"
,
"mlx5_roce0
,mlx5_roce1
"
,
"--disable-overlap-schedule"
,
"--disable-overlap-schedule"
,
]
]
cls
.
process_prefill
=
popen_launch_pd_server
(
cls
.
process_prefill
=
popen_launch_pd_server
(
...
@@ -98,9 +68,9 @@ class TestPDPPAccuracy(unittest.TestCase):
...
@@ -98,9 +68,9 @@ class TestPDPPAccuracy(unittest.TestCase):
"--tp"
,
"--tp"
,
"1"
,
"1"
,
"--base-gpu-id"
,
"--base-gpu-id"
,
"
1
"
,
"
2
"
,
"--disaggregation-ib-device"
,
"--disaggregation-ib-device"
,
"mlx5_roce
1
"
,
"mlx5_roce
2
"
,
]
]
cls
.
process_decode
=
popen_launch_pd_server
(
cls
.
process_decode
=
popen_launch_pd_server
(
cls
.
model
,
cls
.
model
,
...
@@ -109,10 +79,6 @@ class TestPDPPAccuracy(unittest.TestCase):
...
@@ -109,10 +79,6 @@ class TestPDPPAccuracy(unittest.TestCase):
other_args
=
decode_args
,
other_args
=
decode_args
,
)
)
@
classmethod
def
tearDownClass
(
cls
):
kill_process_tree
(
cls
.
process
.
pid
)
def
test_gsm8k
(
self
):
def
test_gsm8k
(
self
):
args
=
SimpleNamespace
(
args
=
SimpleNamespace
(
num_shots
=
5
,
num_shots
=
5
,
...
@@ -120,8 +86,8 @@ class TestPDPPAccuracy(unittest.TestCase):
...
@@ -120,8 +86,8 @@ class TestPDPPAccuracy(unittest.TestCase):
num_questions
=
200
,
num_questions
=
200
,
max_new_tokens
=
512
,
max_new_tokens
=
512
,
parallel
=
128
,
parallel
=
128
,
host
=
"http://
127.0.0.1
"
,
host
=
f
"http://
{
self
.
base_host
}
"
,
port
=
int
(
self
.
base_url
.
split
(
":"
)[
-
1
]
),
port
=
int
(
self
.
lb_port
),
)
)
metrics
=
run_eval
(
args
)
metrics
=
run_eval
(
args
)
print
(
f
"
{
metrics
=
}
"
)
print
(
f
"
{
metrics
=
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment