Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
00974e4f
Unverified
Commit
00974e4f
authored
Sep 06, 2025
by
Shangming Cai
Committed by
GitHub
Sep 06, 2025
Browse files
[CI] Refactor disaggregation tests (#10068)
Signed-off-by:
Shangming Cai
<
csmthu@gmail.com
>
parent
5f1eb204
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
100 additions
and
353 deletions
+100
-353
python/sglang/test/test_disaggregation_utils.py
python/sglang/test/test_disaggregation_utils.py
+66
-0
test/srt/run_suite.py
test/srt/run_suite.py
+1
-0
test/srt/test_disaggregation.py
test/srt/test_disaggregation.py
+17
-196
test/srt/test_disaggregation_different_tp.py
test/srt/test_disaggregation_different_tp.py
+5
-112
test/srt/test_disaggregation_pp.py
test/srt/test_disaggregation_pp.py
+11
-45
No files found.
python/sglang/test/test_disaggregation_utils.py
0 → 100644
View file @
00974e4f
import
time
import
requests
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.test_utils
import
(
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
CustomTestCase
,
popen_with_error_check
,
)
class TestDisaggregationBase(CustomTestCase):
    """Shared fixture helpers for prefill/decode disaggregation tests.

    Subclasses are expected to define ``prefill_url``, ``decode_url``,
    ``base_host``, ``lb_port`` and ``lb_url`` as class attributes (they are
    read here but never assigned), and to store the server handles they
    start in ``process_prefill`` / ``process_decode`` so that
    ``tearDownClass`` can clean them up.
    """

    # Class-level defaults so tearDownClass never hits AttributeError, even
    # if a subclass's setUpClass raises (or skips super().setUpClass())
    # before the handles are assigned.
    process_lb = None
    process_decode = None
    process_prefill = None

    @classmethod
    def setUpClass(cls):
        """Reset the tracked process handles before a test class runs."""
        cls.process_lb, cls.process_decode, cls.process_prefill = None, None, None

    @classmethod
    def launch_lb(cls):
        """Start the router as a load balancer and block until it is healthy.

        The resulting process handle is stored in ``cls.process_lb``.
        ``cls.lb_port`` must be a string, since the command is both joined
        for logging and passed as an argv list.
        """
        lb_command = [
            "python3",
            "-m",
            "sglang_router.launch_router",
            "--pd-disaggregation",
            "--mini-lb",  # FIXME: remove this
            "--prefill",
            cls.prefill_url,
            "--decode",
            cls.decode_url,
            "--host",
            cls.base_host,
            "--port",
            cls.lb_port,
        ]
        print("Starting load balancer:", " ".join(lb_command))
        cls.process_lb = popen_with_error_check(lb_command)
        cls.wait_server_ready(cls.lb_url + "/health")

    @classmethod
    def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
        """Poll ``url`` once per second until it answers with HTTP 200.

        Args:
            url: Full URL of the health endpoint to poll.
            timeout: Overall deadline in seconds.

        Raises:
            RuntimeError: If no 200 response is received within ``timeout``.
        """
        start_time = time.perf_counter()
        while True:
            # Bound each request by the time left (at least 1s); without a
            # per-request timeout a stalled connection could block forever
            # and the overall deadline below would never be checked.
            remaining = timeout - (time.perf_counter() - start_time)
            try:
                response = requests.get(url, timeout=max(1.0, remaining))
                if response.status_code == 200:
                    print(f"Server {url} is ready")
                    return
            except Exception:
                # Deliberate best-effort: connection errors are expected
                # while the server is still starting up; just retry.
                pass

            if time.perf_counter() - start_time > timeout:
                raise RuntimeError(f"Server {url} failed to start in {timeout}s")
            time.sleep(1)

    @classmethod
    def tearDownClass(cls):
        """Kill every process that was started, then pause briefly."""
        for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
            if process:
                try:
                    kill_process_tree(process.pid)
                except Exception as e:
                    # Keep going so one failed kill does not leak the others.
                    print(f"Error killing process {process.pid}: {e}")

        # wait for 5 seconds
        time.sleep(5)
test/srt/run_suite.py
View file @
00974e4f
...
...
@@ -139,6 +139,7 @@ suites = {
TestFile
(
"lora/test_lora_llama4.py"
,
600
),
TestFile
(
"test_disaggregation.py"
,
499
),
TestFile
(
"test_disaggregation_different_tp.py"
,
155
),
TestFile
(
"test_disaggregation_pp.py"
,
60
),
TestFile
(
"test_full_deepseek_v3.py"
,
333
),
],
"per-commit-8-gpu-b200"
:
[
...
...
test/srt/test_disaggregation.py
View file @
00974e4f
...
...
@@ -7,21 +7,19 @@ from urllib.parse import urlparse
import
requests
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.few_shot_gsm8k
import
run_eval
as
run_eval_few_shot_gsm8k
from
sglang.test.test_disaggregation_utils
import
TestDisaggregationBase
from
sglang.test.test_utils
import
(
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
,
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
CustomTestCase
,
popen_launch_pd_server
,
popen_with_error_check
,
)
class
TestDisaggregationAccuracy
(
CustomTestC
ase
):
class
TestDisaggregationAccuracy
(
TestDisaggregationB
ase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
...
...
@@ -44,25 +42,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
lb_command
=
[
"python3"
,
"-m"
,
"sglang_router.launch_router"
,
"--pd-disaggregation"
,
"--mini-lb"
,
# FIXME: remove this
"--prefill"
,
cls
.
prefill_url
,
"--decode"
,
cls
.
decode_url
,
"--host"
,
cls
.
base_host
,
"--port"
,
cls
.
lb_port
,
]
print
(
"Starting load balancer:"
,
" "
.
join
(
lb_command
))
cls
.
process_lb
=
popen_with_error_check
(
lb_command
)
cls
.
wait_server_ready
(
cls
.
lb_url
+
"/health"
)
cls
.
launch_lb
()
@
classmethod
def
start_prefill
(
cls
):
...
...
@@ -102,34 +82,6 @@ class TestDisaggregationAccuracy(CustomTestCase):
other_args
=
decode_args
,
)
@
classmethod
def
wait_server_ready
(
cls
,
url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
):
start_time
=
time
.
perf_counter
()
while
True
:
try
:
response
=
requests
.
get
(
url
)
if
response
.
status_code
==
200
:
print
(
f
"Server
{
url
}
is ready"
)
return
except
Exception
:
pass
if
time
.
perf_counter
()
-
start_time
>
timeout
:
raise
RuntimeError
(
f
"Server
{
url
}
failed to start in
{
timeout
}
s"
)
time
.
sleep
(
1
)
@
classmethod
def
tearDownClass
(
cls
):
for
process
in
[
cls
.
process_lb
,
cls
.
process_decode
,
cls
.
process_prefill
]:
if
process
:
try
:
kill_process_tree
(
process
.
pid
)
except
Exception
as
e
:
print
(
f
"Error killing process
{
process
.
pid
}
:
{
e
}
"
)
# wait for 5 seconds
time
.
sleep
(
5
)
def
test_gsm8k
(
self
):
args
=
SimpleNamespace
(
num_shots
=
5
,
...
...
@@ -199,7 +151,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
json
.
loads
(
output
)
class
TestDisaggregationMooncakeFailure
(
CustomTestC
ase
):
class
TestDisaggregationMooncakeFailure
(
TestDisaggregationB
ase
):
@
classmethod
def
setUpClass
(
cls
):
# set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure
...
...
@@ -225,25 +177,12 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
lb_command
=
[
"python3"
,
"-m"
,
"sglang_router.launch_router"
,
"--pd-disaggregation"
,
"--mini-lb"
,
# FIXME: remove this
"--prefill"
,
cls
.
prefill_url
,
"--decode"
,
cls
.
decode_url
,
"--host"
,
cls
.
base_host
,
"--port"
,
cls
.
lb_port
,
]
cls
.
launch_lb
()
print
(
"Starting load balancer:"
,
" "
.
join
(
lb_command
))
cls
.
process_lb
=
popen_with_error_check
(
lb_command
)
cls
.
wait_server_ready
(
cls
.
lb_url
+
"/health"
)
@
classmethod
def
tearDownClass
(
cls
):
os
.
environ
.
pop
(
"DISAGGREGATION_TEST_FAILURE_PROB"
)
super
().
tearDownClass
()
@
classmethod
def
start_prefill
(
cls
):
...
...
@@ -283,36 +222,6 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
other_args
=
decode_args
,
)
@
classmethod
def
wait_server_ready
(
cls
,
url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
):
start_time
=
time
.
perf_counter
()
while
True
:
try
:
response
=
requests
.
get
(
url
)
if
response
.
status_code
==
200
:
print
(
f
"Server
{
url
}
is ready"
)
return
except
Exception
:
pass
if
time
.
perf_counter
()
-
start_time
>
timeout
:
raise
RuntimeError
(
f
"Server
{
url
}
failed to start in
{
timeout
}
s"
)
time
.
sleep
(
1
)
@
classmethod
def
tearDownClass
(
cls
):
# unset DISAGGREGATION_TEST_FAILURE_PROB
os
.
environ
.
pop
(
"DISAGGREGATION_TEST_FAILURE_PROB"
)
for
process
in
[
cls
.
process_lb
,
cls
.
process_decode
,
cls
.
process_prefill
]:
if
process
:
try
:
kill_process_tree
(
process
.
pid
)
except
Exception
as
e
:
print
(
f
"Error killing process
{
process
.
pid
}
:
{
e
}
"
)
# wait for 5 seconds
time
.
sleep
(
5
)
def
test_gsm8k
(
self
):
args
=
SimpleNamespace
(
num_shots
=
5
,
...
...
@@ -341,7 +250,7 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
raise
e
from
health_check_error
class
TestDisaggregationMooncakeSpec
(
CustomTestC
ase
):
class
TestDisaggregationMooncakeSpec
(
TestDisaggregationB
ase
):
@
classmethod
def
setUpClass
(
cls
):
...
...
@@ -380,41 +289,7 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
lb_command
=
[
"python3"
,
"-m"
,
"sglang_router.launch_router"
,
"--pd-disaggregation"
,
"--mini-lb"
,
# FIXME: remove this
"--prefill"
,
cls
.
prefill_url
,
"--decode"
,
cls
.
decode_url
,
"--host"
,
cls
.
base_host
,
"--port"
,
cls
.
lb_port
,
]
print
(
"Starting load balancer:"
,
" "
.
join
(
lb_command
))
cls
.
process_lb
=
popen_with_error_check
(
lb_command
)
cls
.
wait_server_ready
(
cls
.
lb_url
+
"/health"
)
@
classmethod
def
wait_server_ready
(
cls
,
url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
):
start_time
=
time
.
perf_counter
()
while
True
:
try
:
response
=
requests
.
get
(
url
)
if
response
.
status_code
==
200
:
print
(
f
"Server
{
url
}
is ready"
)
return
except
Exception
:
pass
if
time
.
perf_counter
()
-
start_time
>
timeout
:
raise
RuntimeError
(
f
"Server
{
url
}
failed to start in
{
timeout
}
s"
)
time
.
sleep
(
1
)
cls
.
launch_lb
()
@
classmethod
def
start_prefill
(
cls
):
...
...
@@ -454,18 +329,6 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
other_args
=
decode_args
,
)
@
classmethod
def
tearDownClass
(
cls
):
for
process
in
[
cls
.
process_lb
,
cls
.
process_decode
,
cls
.
process_prefill
]:
if
process
:
try
:
kill_process_tree
(
process
.
pid
)
except
Exception
as
e
:
print
(
f
"Error killing process
{
process
.
pid
}
:
{
e
}
"
)
# wait for 5 seconds
time
.
sleep
(
5
)
def
test_gsm8k
(
self
):
args
=
SimpleNamespace
(
num_shots
=
5
,
...
...
@@ -482,7 +345,7 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
self
.
assertGreater
(
metrics
[
"accuracy"
],
0.20
)
class
TestDisaggregationSimulatedRetract
(
CustomTestC
ase
):
class
TestDisaggregationSimulatedRetract
(
TestDisaggregationB
ase
):
@
classmethod
def
setUpClass
(
cls
):
os
.
environ
[
"SGLANG_TEST_RETRACT"
]
=
"true"
...
...
@@ -506,25 +369,12 @@ class TestDisaggregationSimulatedRetract(CustomTestCase):
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
lb_command
=
[
"python3"
,
"-m"
,
"sglang_router.launch_router"
,
"--pd-disaggregation"
,
"--mini-lb"
,
# FIXME: remove this
"--prefill"
,
cls
.
prefill_url
,
"--decode"
,
cls
.
decode_url
,
"--host"
,
cls
.
base_host
,
"--port"
,
cls
.
lb_port
,
]
cls
.
launch_lb
()
print
(
"Starting load balancer:"
,
" "
.
join
(
lb_command
))
cls
.
process_lb
=
popen_with_error_check
(
lb_command
)
cls
.
wait_server_ready
(
cls
.
lb_url
+
"/health"
)
@
classmethod
def
tearDownClass
(
cls
):
os
.
environ
.
pop
(
"SGLANG_TEST_RETRACT"
)
super
().
tearDownClass
()
@
classmethod
def
start_prefill
(
cls
):
...
...
@@ -564,35 +414,6 @@ class TestDisaggregationSimulatedRetract(CustomTestCase):
other_args
=
decode_args
,
)
@
classmethod
def
wait_server_ready
(
cls
,
url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
):
start_time
=
time
.
perf_counter
()
while
True
:
try
:
response
=
requests
.
get
(
url
)
if
response
.
status_code
==
200
:
print
(
f
"Server
{
url
}
is ready"
)
return
except
Exception
:
pass
if
time
.
perf_counter
()
-
start_time
>
timeout
:
raise
RuntimeError
(
f
"Server
{
url
}
failed to start in
{
timeout
}
s"
)
time
.
sleep
(
1
)
@
classmethod
def
tearDownClass
(
cls
):
os
.
environ
.
pop
(
"SGLANG_TEST_RETRACT"
)
for
process
in
[
cls
.
process_lb
,
cls
.
process_decode
,
cls
.
process_prefill
]:
if
process
:
try
:
kill_process_tree
(
process
.
pid
)
except
Exception
as
e
:
print
(
f
"Error killing process
{
process
.
pid
}
:
{
e
}
"
)
# wait for 5 seconds
time
.
sleep
(
5
)
def
test_gsm8k
(
self
):
args
=
SimpleNamespace
(
num_shots
=
5
,
...
...
test/srt/test_disaggregation_different_tp.py
View file @
00974e4f
import
os
import
subprocess
import
time
import
unittest
from
types
import
SimpleNamespace
from
urllib.parse
import
urlparse
import
requests
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.few_shot_gsm8k
import
run_eval
as
run_eval_few_shot_gsm8k
from
sglang.test.test_disaggregation_utils
import
TestDisaggregationBase
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST_MLA
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
CustomTestCase
,
popen_launch_pd_server
,
popen_with_error_check
,
)
class
TestDisaggregationMooncakePrefillLargerTP
(
CustomTestC
ase
):
class
TestDisaggregationMooncakePrefillLargerTP
(
TestDisaggregationB
ase
):
@
classmethod
def
setUpClass
(
cls
):
# Temporarily disable JIT DeepGEMM
...
...
@@ -46,25 +41,7 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
lb_command
=
[
"python3"
,
"-m"
,
"sglang_router.launch_router"
,
"--pd-disaggregation"
,
"--mini-lb"
,
# FIXME: remove this
"--prefill"
,
cls
.
prefill_url
,
"--decode"
,
cls
.
decode_url
,
"--host"
,
cls
.
base_host
,
"--port"
,
cls
.
lb_port
,
]
print
(
"Starting load balancer:"
,
" "
.
join
(
lb_command
))
cls
.
process_lb
=
popen_with_error_check
(
lb_command
)
cls
.
wait_server_ready
(
cls
.
lb_url
+
"/health"
)
cls
.
launch_lb
()
@
classmethod
def
start_prefill
(
cls
):
...
...
@@ -104,39 +81,6 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
other_args
=
decode_args
,
)
@
classmethod
def
wait_server_ready
(
cls
,
url
,
timeout
=
60
):
start_time
=
time
.
perf_counter
()
while
True
:
try
:
response
=
requests
.
get
(
url
)
if
response
.
status_code
==
200
:
print
(
f
"Server
{
url
}
is ready"
)
return
except
Exception
:
pass
if
time
.
perf_counter
()
-
start_time
>
timeout
:
raise
RuntimeError
(
f
"Server
{
url
}
failed to start in
{
timeout
}
s"
)
time
.
sleep
(
1
)
@
classmethod
def
tearDownClass
(
cls
):
# Restore JIT DeepGEMM environment variable
if
cls
.
original_jit_deepgemm
is
not
None
:
os
.
environ
[
"SGL_ENABLE_JIT_DEEPGEMM"
]
=
cls
.
original_jit_deepgemm
else
:
os
.
environ
.
pop
(
"SGL_ENABLE_JIT_DEEPGEMM"
,
None
)
for
process
in
[
cls
.
process_lb
,
cls
.
process_decode
,
cls
.
process_prefill
]:
if
process
:
try
:
kill_process_tree
(
process
.
pid
)
except
Exception
as
e
:
print
(
f
"Error killing process
{
process
.
pid
}
:
{
e
}
"
)
# wait for 5 seconds
time
.
sleep
(
5
)
def
test_gsm8k
(
self
):
args
=
SimpleNamespace
(
num_shots
=
5
,
...
...
@@ -153,7 +97,7 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
self
.
assertGreater
(
metrics
[
"accuracy"
],
0.60
)
class
TestDisaggregationMooncakeDecodeLargerTP
(
CustomTestC
ase
):
class
TestDisaggregationMooncakeDecodeLargerTP
(
TestDisaggregationB
ase
):
@
classmethod
def
setUpClass
(
cls
):
# Temporarily disable JIT DeepGEMM
...
...
@@ -180,25 +124,7 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase):
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
lb_command
=
[
"python3"
,
"-m"
,
"sglang_router.launch_router"
,
"--pd-disaggregation"
,
"--mini-lb"
,
# FIXME: remove this
"--prefill"
,
cls
.
prefill_url
,
"--decode"
,
cls
.
decode_url
,
"--host"
,
cls
.
base_host
,
"--port"
,
cls
.
lb_port
,
]
print
(
"Starting load balancer:"
,
" "
.
join
(
lb_command
))
cls
.
process_lb
=
popen_with_error_check
(
lb_command
)
cls
.
wait_server_ready
(
cls
.
lb_url
+
"/health"
)
cls
.
launch_lb
()
@
classmethod
def
start_prefill
(
cls
):
...
...
@@ -238,39 +164,6 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase):
other_args
=
decode_args
,
)
@
classmethod
def
wait_server_ready
(
cls
,
url
,
timeout
=
60
):
start_time
=
time
.
perf_counter
()
while
True
:
try
:
response
=
requests
.
get
(
url
)
if
response
.
status_code
==
200
:
print
(
f
"Server
{
url
}
is ready"
)
return
except
Exception
:
pass
if
time
.
perf_counter
()
-
start_time
>
timeout
:
raise
RuntimeError
(
f
"Server
{
url
}
failed to start in
{
timeout
}
s"
)
time
.
sleep
(
1
)
@
classmethod
def
tearDownClass
(
cls
):
# Restore JIT DeepGEMM environment variable
if
cls
.
original_jit_deepgemm
is
not
None
:
os
.
environ
[
"SGL_ENABLE_JIT_DEEPGEMM"
]
=
cls
.
original_jit_deepgemm
else
:
os
.
environ
.
pop
(
"SGL_ENABLE_JIT_DEEPGEMM"
,
None
)
for
process
in
[
cls
.
process_lb
,
cls
.
process_decode
,
cls
.
process_prefill
]:
if
process
:
try
:
kill_process_tree
(
process
.
pid
)
except
Exception
as
e
:
print
(
f
"Error killing process
{
process
.
pid
}
:
{
e
}
"
)
# wait for 5 seconds
time
.
sleep
(
5
)
def
test_gsm8k
(
self
):
args
=
SimpleNamespace
(
num_shots
=
5
,
...
...
test/srt/test_disaggregation_pp.py
View file @
00974e4f
import
json
import
os
import
random
import
time
import
unittest
from
concurrent.futures
import
ThreadPoolExecutor
from
types
import
SimpleNamespace
from
typing
import
List
,
Optional
from
urllib.parse
import
urlparse
import
requests
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.few_shot_gsm8k
import
run_eval
from
sglang.test.
runners
import
DEFAULT_PROMPTS
from
sglang.test.
test_disaggregation_utils
import
TestDisaggregationBase
from
sglang.test.test_utils
import
(
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
,
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
CustomTestCase
,
popen_launch_server
,
popen_launch_pd_server
,
)
class
Test
PDPPAccuracy
(
unittest
.
TestC
ase
):
class
Test
DisaggregationPPAccuracy
(
TestDisaggregationB
ase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
...
...
@@ -46,27 +36,7 @@ class TestPDPPAccuracy(unittest.TestCase):
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
lb_command
=
[
"python3"
,
"-m"
,
"sglang_router.launch_router"
,
"--pd-disaggregation"
,
"--mini-lb"
,
# FIXME: remove this
"--prefill"
,
cls
.
prefill_url
,
"--decode"
,
cls
.
decode_url
,
"--host"
,
cls
.
base_host
,
"--port"
,
cls
.
lb_port
,
]
print
(
"Starting load balancer:"
,
" "
.
join
(
lb_command
))
cls
.
process_lb
=
subprocess
.
Popen
(
lb_command
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
cls
.
wait_server_ready
(
cls
.
lb_url
+
"/health"
)
cls
.
launch_lb
()
@
classmethod
def
start_prefill
(
cls
):
...
...
@@ -75,11 +45,11 @@ class TestPDPPAccuracy(unittest.TestCase):
"--disaggregation-mode"
,
"prefill"
,
"--tp-size"
,
"
2
"
,
"
1
"
,
"--pp-size"
,
"2"
,
"--disaggregation-ib-device"
,
"mlx5_roce0"
,
"mlx5_roce0
,mlx5_roce1
"
,
"--disable-overlap-schedule"
,
]
cls
.
process_prefill
=
popen_launch_pd_server
(
...
...
@@ -98,9 +68,9 @@ class TestPDPPAccuracy(unittest.TestCase):
"--tp"
,
"1"
,
"--base-gpu-id"
,
"
1
"
,
"
2
"
,
"--disaggregation-ib-device"
,
"mlx5_roce
1
"
,
"mlx5_roce
2
"
,
]
cls
.
process_decode
=
popen_launch_pd_server
(
cls
.
model
,
...
...
@@ -109,10 +79,6 @@ class TestPDPPAccuracy(unittest.TestCase):
other_args
=
decode_args
,
)
@
classmethod
def
tearDownClass
(
cls
):
kill_process_tree
(
cls
.
process
.
pid
)
def
test_gsm8k
(
self
):
args
=
SimpleNamespace
(
num_shots
=
5
,
...
...
@@ -120,8 +86,8 @@ class TestPDPPAccuracy(unittest.TestCase):
num_questions
=
200
,
max_new_tokens
=
512
,
parallel
=
128
,
host
=
"http://
127.0.0.1
"
,
port
=
int
(
self
.
base_url
.
split
(
":"
)[
-
1
]
),
host
=
f
"http://
{
self
.
base_host
}
"
,
port
=
int
(
self
.
lb_port
),
)
metrics
=
run_eval
(
args
)
print
(
f
"
{
metrics
=
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment