Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
c560410d
Unverified
Commit
c560410d
authored
Oct 06, 2025
by
Shangming Cai
Committed by
GitHub
Oct 05, 2025
Browse files
Refactor and optimize mooncake CI (#11162)
Signed-off-by:
Shangming Cai
<
csmthu@gmail.com
>
parent
590f2da0
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
201 additions
and
118 deletions
+201
-118
python/sglang/test/test_disaggregation_utils.py
python/sglang/test/test_disaggregation_utils.py
+12
-1
test/srt/run_suite.py
test/srt/run_suite.py
+5
-5
test/srt/test_disaggregation.py
test/srt/test_disaggregation.py
+19
-56
test/srt/test_disaggregation_different_tp.py
test/srt/test_disaggregation_different_tp.py
+159
-30
test/srt/test_disaggregation_dp_attention.py
test/srt/test_disaggregation_dp_attention.py
+1
-10
test/srt/test_disaggregation_pp.py
test/srt/test_disaggregation_pp.py
+5
-16
No files found.
python/sglang/test/test_disaggregation_utils.py
View file @
c560410d
import
time
import
time
from
urllib.parse
import
urlparse
import
requests
import
requests
from
sglang.srt.utils
import
kill_process_tree
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.test_utils
import
(
from
sglang.test.test_utils
import
(
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
CustomTestCase
,
CustomTestCase
,
popen_with_error_check
,
popen_with_error_check
,
)
)
...
@@ -13,8 +15,17 @@ from sglang.test.test_utils import (
...
@@ -13,8 +15,17 @@ from sglang.test.test_utils import (
class
TestDisaggregationBase
(
CustomTestCase
):
class
TestDisaggregationBase
(
CustomTestCase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
parsed_url
=
urlparse
(
DEFAULT_URL_FOR_TEST
)
cls
.
base_host
=
parsed_url
.
hostname
base_port
=
str
(
parsed_url
.
port
)
cls
.
lb_port
=
base_port
cls
.
prefill_port
=
f
"
{
int
(
base_port
)
+
100
}
"
cls
.
decode_port
=
f
"
{
int
(
base_port
)
+
200
}
"
cls
.
prefill_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
prefill_port
}
"
cls
.
decode_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
decode_port
}
"
cls
.
lb_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
lb_port
}
"
print
(
f
"
{
cls
.
base_host
=
}
{
cls
.
lb_port
=
}
{
cls
.
prefill_port
=
}
{
cls
.
decode_port
=
}
"
)
cls
.
process_lb
,
cls
.
process_decode
,
cls
.
process_prefill
=
None
,
None
,
None
cls
.
process_lb
,
cls
.
process_decode
,
cls
.
process_prefill
=
None
,
None
,
None
pass
@
classmethod
@
classmethod
def
launch_lb
(
cls
):
def
launch_lb
(
cls
):
...
...
test/srt/run_suite.py
View file @
c560410d
...
@@ -146,12 +146,12 @@ suites = {
...
@@ -146,12 +146,12 @@ suites = {
],
],
"per-commit-8-gpu"
:
[
"per-commit-8-gpu"
:
[
TestFile
(
"hicache/test_hicache_storage_mooncake_backend.py"
,
400
),
TestFile
(
"hicache/test_hicache_storage_mooncake_backend.py"
,
400
),
TestFile
(
"lora/test_lora_llama4.py"
,
6
00
),
TestFile
(
"lora/test_lora_llama4.py"
,
4
00
),
TestFile
(
"test_disaggregation.py"
,
499
),
TestFile
(
"test_disaggregation.py"
,
600
),
TestFile
(
"test_disaggregation_dp_attention.py"
,
155
),
TestFile
(
"test_disaggregation_dp_attention.py"
,
155
),
TestFile
(
"test_disaggregation_different_tp.py"
,
155
),
TestFile
(
"test_disaggregation_different_tp.py"
,
600
),
TestFile
(
"test_disaggregation_pp.py"
,
6
0
),
TestFile
(
"test_disaggregation_pp.py"
,
14
0
),
TestFile
(
"test_full_deepseek_v3.py"
,
333
),
TestFile
(
"test_full_deepseek_v3.py"
,
550
),
],
],
"per-commit-4-gpu-b200"
:
[
"per-commit-4-gpu-b200"
:
[
# TestFile("test_gpt_oss_4gpu.py", 600),
# TestFile("test_gpt_oss_4gpu.py", 600),
...
...
test/srt/test_disaggregation.py
View file @
c560410d
...
@@ -3,7 +3,6 @@ import os
...
@@ -3,7 +3,6 @@ import os
import
time
import
time
import
unittest
import
unittest
from
types
import
SimpleNamespace
from
types
import
SimpleNamespace
from
urllib.parse
import
urlparse
import
requests
import
requests
...
@@ -14,7 +13,6 @@ from sglang.test.test_utils import (
...
@@ -14,7 +13,6 @@ from sglang.test.test_utils import (
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
,
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
popen_launch_pd_server
,
popen_launch_pd_server
,
)
)
...
@@ -22,17 +20,8 @@ from sglang.test.test_utils import (
...
@@ -22,17 +20,8 @@ from sglang.test.test_utils import (
class
TestDisaggregationAccuracy
(
TestDisaggregationBase
):
class
TestDisaggregationAccuracy
(
TestDisaggregationBase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
super
().
setUpClass
()
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
parsed_url
=
urlparse
(
DEFAULT_URL_FOR_TEST
)
cls
.
base_host
=
parsed_url
.
hostname
base_port
=
str
(
parsed_url
.
port
)
cls
.
lb_port
=
base_port
cls
.
prefill_port
=
f
"
{
int
(
base_port
)
+
100
}
"
cls
.
decode_port
=
f
"
{
int
(
base_port
)
+
200
}
"
cls
.
prefill_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
prefill_port
}
"
cls
.
decode_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
decode_port
}
"
cls
.
lb_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
lb_port
}
"
print
(
f
"
{
cls
.
base_host
=
}
{
cls
.
lb_port
=
}
{
cls
.
prefill_port
=
}
{
cls
.
decode_port
=
}
"
)
# Non blocking start servers
# Non blocking start servers
cls
.
start_prefill
()
cls
.
start_prefill
()
...
@@ -51,9 +40,9 @@ class TestDisaggregationAccuracy(TestDisaggregationBase):
...
@@ -51,9 +40,9 @@ class TestDisaggregationAccuracy(TestDisaggregationBase):
"--disaggregation-mode"
,
"--disaggregation-mode"
,
"prefill"
,
"prefill"
,
"--tp"
,
"--tp"
,
"
1
"
,
"
2
"
,
"--disaggregation-ib-device"
,
"--disaggregation-ib-device"
,
"mlx5_roce0"
,
"mlx5_roce0
,mlx5_roce1
"
,
]
]
cls
.
process_prefill
=
popen_launch_pd_server
(
cls
.
process_prefill
=
popen_launch_pd_server
(
cls
.
model
,
cls
.
model
,
...
@@ -69,11 +58,11 @@ class TestDisaggregationAccuracy(TestDisaggregationBase):
...
@@ -69,11 +58,11 @@ class TestDisaggregationAccuracy(TestDisaggregationBase):
"--disaggregation-mode"
,
"--disaggregation-mode"
,
"decode"
,
"decode"
,
"--tp"
,
"--tp"
,
"
1
"
,
"
2
"
,
"--base-gpu-id"
,
"--base-gpu-id"
,
"
1
"
,
"
2
"
,
"--disaggregation-ib-device"
,
"--disaggregation-ib-device"
,
"mlx5_roce
1
"
,
"mlx5_roce
2,mlx5_roce3
"
,
]
]
cls
.
process_decode
=
popen_launch_pd_server
(
cls
.
process_decode
=
popen_launch_pd_server
(
cls
.
model
,
cls
.
model
,
...
@@ -154,20 +143,11 @@ class TestDisaggregationAccuracy(TestDisaggregationBase):
...
@@ -154,20 +143,11 @@ class TestDisaggregationAccuracy(TestDisaggregationBase):
class
TestDisaggregationMooncakeFailure
(
TestDisaggregationBase
):
class
TestDisaggregationMooncakeFailure
(
TestDisaggregationBase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
super
().
setUpClass
()
# set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure
# set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure
os
.
environ
[
"DISAGGREGATION_TEST_FAILURE_PROB"
]
=
"0.05"
os
.
environ
[
"DISAGGREGATION_TEST_FAILURE_PROB"
]
=
"0.05"
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
parsed_url
=
urlparse
(
DEFAULT_URL_FOR_TEST
)
cls
.
base_host
=
parsed_url
.
hostname
base_port
=
str
(
parsed_url
.
port
)
cls
.
lb_port
=
base_port
cls
.
prefill_port
=
f
"
{
int
(
base_port
)
+
100
}
"
cls
.
decode_port
=
f
"
{
int
(
base_port
)
+
200
}
"
cls
.
prefill_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
prefill_port
}
"
cls
.
decode_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
decode_port
}
"
cls
.
lb_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
lb_port
}
"
print
(
f
"
{
cls
.
base_host
=
}
{
cls
.
lb_port
=
}
{
cls
.
prefill_port
=
}
{
cls
.
decode_port
=
}
"
)
# Non blocking start servers
# Non blocking start servers
cls
.
start_prefill
()
cls
.
start_prefill
()
...
@@ -191,9 +171,9 @@ class TestDisaggregationMooncakeFailure(TestDisaggregationBase):
...
@@ -191,9 +171,9 @@ class TestDisaggregationMooncakeFailure(TestDisaggregationBase):
"--disaggregation-mode"
,
"--disaggregation-mode"
,
"prefill"
,
"prefill"
,
"--tp"
,
"--tp"
,
"
1
"
,
"
2
"
,
"--disaggregation-ib-device"
,
"--disaggregation-ib-device"
,
"mlx5_roce0"
,
"mlx5_roce0
,mlx5_roce1
"
,
]
]
cls
.
process_prefill
=
popen_launch_pd_server
(
cls
.
process_prefill
=
popen_launch_pd_server
(
cls
.
model
,
cls
.
model
,
...
@@ -209,11 +189,11 @@ class TestDisaggregationMooncakeFailure(TestDisaggregationBase):
...
@@ -209,11 +189,11 @@ class TestDisaggregationMooncakeFailure(TestDisaggregationBase):
"--disaggregation-mode"
,
"--disaggregation-mode"
,
"decode"
,
"decode"
,
"--tp"
,
"--tp"
,
"
1
"
,
"
2
"
,
"--base-gpu-id"
,
"--base-gpu-id"
,
"
1
"
,
"
2
"
,
"--disaggregation-ib-device"
,
"--disaggregation-ib-device"
,
"mlx5_roce
1
"
,
"mlx5_roce
2,mlx5_roce3
"
,
]
]
cls
.
process_decode
=
popen_launch_pd_server
(
cls
.
process_decode
=
popen_launch_pd_server
(
cls
.
model
,
cls
.
model
,
...
@@ -254,17 +234,9 @@ class TestDisaggregationMooncakeSpec(TestDisaggregationBase):
...
@@ -254,17 +234,9 @@ class TestDisaggregationMooncakeSpec(TestDisaggregationBase):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
super
().
setUpClass
()
cls
.
model
=
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
cls
.
model
=
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
cls
.
draft_model
=
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
cls
.
draft_model
=
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
parsed_url
=
urlparse
(
DEFAULT_URL_FOR_TEST
)
cls
.
base_host
=
parsed_url
.
hostname
base_port
=
str
(
parsed_url
.
port
)
cls
.
lb_port
=
base_port
cls
.
prefill_port
=
f
"
{
int
(
base_port
)
+
100
}
"
cls
.
decode_port
=
f
"
{
int
(
base_port
)
+
200
}
"
cls
.
prefill_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
prefill_port
}
"
cls
.
decode_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
decode_port
}
"
cls
.
lb_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
lb_port
}
"
cls
.
spec_args
=
[
cls
.
spec_args
=
[
"--speculative-algorithm"
,
"--speculative-algorithm"
,
"EAGLE"
,
"EAGLE"
,
...
@@ -348,18 +320,9 @@ class TestDisaggregationMooncakeSpec(TestDisaggregationBase):
...
@@ -348,18 +320,9 @@ class TestDisaggregationMooncakeSpec(TestDisaggregationBase):
class
TestDisaggregationSimulatedRetract
(
TestDisaggregationBase
):
class
TestDisaggregationSimulatedRetract
(
TestDisaggregationBase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
super
().
setUpClass
()
os
.
environ
[
"SGLANG_TEST_RETRACT"
]
=
"true"
os
.
environ
[
"SGLANG_TEST_RETRACT"
]
=
"true"
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
parsed_url
=
urlparse
(
DEFAULT_URL_FOR_TEST
)
cls
.
base_host
=
parsed_url
.
hostname
base_port
=
str
(
parsed_url
.
port
)
cls
.
lb_port
=
base_port
cls
.
prefill_port
=
f
"
{
int
(
base_port
)
+
100
}
"
cls
.
decode_port
=
f
"
{
int
(
base_port
)
+
200
}
"
cls
.
prefill_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
prefill_port
}
"
cls
.
decode_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
decode_port
}
"
cls
.
lb_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
lb_port
}
"
print
(
f
"
{
cls
.
base_host
=
}
{
cls
.
lb_port
=
}
{
cls
.
prefill_port
=
}
{
cls
.
decode_port
=
}
"
)
# Non blocking start servers
# Non blocking start servers
cls
.
start_prefill
()
cls
.
start_prefill
()
...
@@ -383,9 +346,9 @@ class TestDisaggregationSimulatedRetract(TestDisaggregationBase):
...
@@ -383,9 +346,9 @@ class TestDisaggregationSimulatedRetract(TestDisaggregationBase):
"--disaggregation-mode"
,
"--disaggregation-mode"
,
"prefill"
,
"prefill"
,
"--tp"
,
"--tp"
,
"
1
"
,
"
2
"
,
"--disaggregation-ib-device"
,
"--disaggregation-ib-device"
,
"mlx5_roce0"
,
"mlx5_roce0
,mlx5_roce1
"
,
]
]
cls
.
process_prefill
=
popen_launch_pd_server
(
cls
.
process_prefill
=
popen_launch_pd_server
(
cls
.
model
,
cls
.
model
,
...
@@ -401,11 +364,11 @@ class TestDisaggregationSimulatedRetract(TestDisaggregationBase):
...
@@ -401,11 +364,11 @@ class TestDisaggregationSimulatedRetract(TestDisaggregationBase):
"--disaggregation-mode"
,
"--disaggregation-mode"
,
"decode"
,
"decode"
,
"--tp"
,
"--tp"
,
"
1
"
,
"
2
"
,
"--base-gpu-id"
,
"--base-gpu-id"
,
"
1
"
,
"
2
"
,
"--disaggregation-ib-device"
,
"--disaggregation-ib-device"
,
"mlx5_roce
1
"
,
"mlx5_roce
2,mlx5_roce3
"
,
]
]
cls
.
process_decode
=
popen_launch_pd_server
(
cls
.
process_decode
=
popen_launch_pd_server
(
cls
.
model
,
cls
.
model
,
...
...
test/srt/test_disaggregation_different_tp.py
View file @
c560410d
...
@@ -2,14 +2,13 @@ import os
...
@@ -2,14 +2,13 @@ import os
import
time
import
time
import
unittest
import
unittest
from
types
import
SimpleNamespace
from
types
import
SimpleNamespace
from
urllib.parse
import
urlparse
from
sglang.test.few_shot_gsm8k
import
run_eval
as
run_eval_few_shot_gsm8k
from
sglang.test.few_shot_gsm8k
import
run_eval
as
run_eval_few_shot_gsm8k
from
sglang.test.test_disaggregation_utils
import
TestDisaggregationBase
from
sglang.test.test_disaggregation_utils
import
TestDisaggregationBase
from
sglang.test.test_utils
import
(
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_MODEL_NAME_FOR_TEST_MLA
,
DEFAULT_MODEL_NAME_FOR_TEST_MLA
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
popen_launch_pd_server
,
popen_launch_pd_server
,
)
)
...
@@ -17,21 +16,12 @@ from sglang.test.test_utils import (
...
@@ -17,21 +16,12 @@ from sglang.test.test_utils import (
class
TestDisaggregationMooncakePrefillLargerTP
(
TestDisaggregationBase
):
class
TestDisaggregationMooncakePrefillLargerTP
(
TestDisaggregationBase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
super
().
setUpClass
()
# Temporarily disable JIT DeepGEMM
# Temporarily disable JIT DeepGEMM
cls
.
original_jit_deepgemm
=
os
.
environ
.
get
(
"SGL_ENABLE_JIT_DEEPGEMM"
)
cls
.
original_jit_deepgemm
=
os
.
environ
.
get
(
"SGL_ENABLE_JIT_DEEPGEMM"
)
os
.
environ
[
"SGL_ENABLE_JIT_DEEPGEMM"
]
=
"false"
os
.
environ
[
"SGL_ENABLE_JIT_DEEPGEMM"
]
=
"false"
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST_MLA
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST_MLA
parsed_url
=
urlparse
(
DEFAULT_URL_FOR_TEST
)
cls
.
base_host
=
parsed_url
.
hostname
base_port
=
str
(
parsed_url
.
port
)
cls
.
lb_port
=
base_port
cls
.
prefill_port
=
f
"
{
int
(
base_port
)
+
100
}
"
cls
.
decode_port
=
f
"
{
int
(
base_port
)
+
200
}
"
cls
.
prefill_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
prefill_port
}
"
cls
.
decode_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
decode_port
}
"
cls
.
lb_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
lb_port
}
"
print
(
f
"
{
cls
.
base_host
=
}
{
cls
.
lb_port
=
}
{
cls
.
prefill_port
=
}
{
cls
.
decode_port
=
}
"
)
# Non blocking start servers
# Non blocking start servers
cls
.
start_prefill
()
cls
.
start_prefill
()
...
@@ -50,7 +40,7 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase):
...
@@ -50,7 +40,7 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase):
"--disaggregation-mode"
,
"--disaggregation-mode"
,
"prefill"
,
"prefill"
,
"--tp"
,
"--tp"
,
"
2
"
,
"
4
"
,
"--disaggregation-ib-device"
,
"--disaggregation-ib-device"
,
"mlx5_roce0,mlx5_roce1"
,
"mlx5_roce0,mlx5_roce1"
,
]
]
...
@@ -68,11 +58,11 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase):
...
@@ -68,11 +58,11 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase):
"--disaggregation-mode"
,
"--disaggregation-mode"
,
"decode"
,
"decode"
,
"--tp"
,
"--tp"
,
"1"
,
"--base-gpu-id"
,
"2"
,
"2"
,
"--base-gpu-id"
,
"4"
,
"--disaggregation-ib-device"
,
"--disaggregation-ib-device"
,
"mlx5_roce
2
"
,
"mlx5_roce
4,mlx5_roce5
"
,
]
]
cls
.
process_decode
=
popen_launch_pd_server
(
cls
.
process_decode
=
popen_launch_pd_server
(
cls
.
model
,
cls
.
model
,
...
@@ -100,21 +90,12 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase):
...
@@ -100,21 +90,12 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase):
class
TestDisaggregationMooncakeDecodeLargerTP
(
TestDisaggregationBase
):
class
TestDisaggregationMooncakeDecodeLargerTP
(
TestDisaggregationBase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
super
().
setUpClass
()
# Temporarily disable JIT DeepGEMM
# Temporarily disable JIT DeepGEMM
cls
.
original_jit_deepgemm
=
os
.
environ
.
get
(
"SGL_ENABLE_JIT_DEEPGEMM"
)
cls
.
original_jit_deepgemm
=
os
.
environ
.
get
(
"SGL_ENABLE_JIT_DEEPGEMM"
)
os
.
environ
[
"SGL_ENABLE_JIT_DEEPGEMM"
]
=
"false"
os
.
environ
[
"SGL_ENABLE_JIT_DEEPGEMM"
]
=
"false"
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST_MLA
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST_MLA
parsed_url
=
urlparse
(
DEFAULT_URL_FOR_TEST
)
cls
.
base_host
=
parsed_url
.
hostname
base_port
=
str
(
parsed_url
.
port
)
cls
.
lb_port
=
base_port
cls
.
prefill_port
=
f
"
{
int
(
base_port
)
+
100
}
"
cls
.
decode_port
=
f
"
{
int
(
base_port
)
+
200
}
"
cls
.
prefill_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
prefill_port
}
"
cls
.
decode_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
decode_port
}
"
cls
.
lb_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
lb_port
}
"
print
(
f
"
{
cls
.
base_host
=
}
{
cls
.
lb_port
=
}
{
cls
.
prefill_port
=
}
{
cls
.
decode_port
=
}
"
)
# Non blocking start servers
# Non blocking start servers
cls
.
start_prefill
()
cls
.
start_prefill
()
...
@@ -133,9 +114,83 @@ class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase):
...
@@ -133,9 +114,83 @@ class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase):
"--disaggregation-mode"
,
"--disaggregation-mode"
,
"prefill"
,
"prefill"
,
"--tp"
,
"--tp"
,
"1"
,
"2"
,
"--disaggregation-ib-device"
,
"mlx5_roce0,mlx5_roce1"
,
]
cls
.
process_prefill
=
popen_launch_pd_server
(
cls
.
model
,
cls
.
prefill_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
other_args
=
prefill_args
,
)
@
classmethod
def
start_decode
(
cls
):
decode_args
=
[
"--trust-remote-code"
,
"--disaggregation-mode"
,
"decode"
,
"--tp"
,
"4"
,
"--base-gpu-id"
,
"4"
,
"--disaggregation-ib-device"
,
"mlx5_roce4,mlx5_roce5"
,
]
cls
.
process_decode
=
popen_launch_pd_server
(
cls
.
model
,
cls
.
decode_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
other_args
=
decode_args
,
)
def
test_gsm8k
(
self
):
args
=
SimpleNamespace
(
num_shots
=
5
,
data_path
=
None
,
num_questions
=
200
,
max_new_tokens
=
512
,
parallel
=
128
,
host
=
f
"http://
{
self
.
base_host
}
"
,
port
=
int
(
self
.
lb_port
),
)
metrics
=
run_eval_few_shot_gsm8k
(
args
)
print
(
f
"Evaluation metrics:
{
metrics
}
"
)
self
.
assertGreater
(
metrics
[
"accuracy"
],
0.60
)
class
TestDisaggregationMooncakeMHAPrefillLargerTP
(
TestDisaggregationBase
):
@
classmethod
def
setUpClass
(
cls
):
super
().
setUpClass
()
# Temporarily disable JIT DeepGEMM
cls
.
original_jit_deepgemm
=
os
.
environ
.
get
(
"SGL_ENABLE_JIT_DEEPGEMM"
)
os
.
environ
[
"SGL_ENABLE_JIT_DEEPGEMM"
]
=
"false"
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
# Non blocking start servers
cls
.
start_prefill
()
cls
.
start_decode
()
# Block until both
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
cls
.
launch_lb
()
@
classmethod
def
start_prefill
(
cls
):
prefill_args
=
[
"--trust-remote-code"
,
"--disaggregation-mode"
,
"prefill"
,
"--tp"
,
"4"
,
"--disaggregation-ib-device"
,
"--disaggregation-ib-device"
,
"mlx5_roce0"
,
"mlx5_roce0
,mlx5_roce1
"
,
]
]
cls
.
process_prefill
=
popen_launch_pd_server
(
cls
.
process_prefill
=
popen_launch_pd_server
(
cls
.
model
,
cls
.
model
,
...
@@ -153,9 +208,83 @@ class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase):
...
@@ -153,9 +208,83 @@ class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase):
"--tp"
,
"--tp"
,
"2"
,
"2"
,
"--base-gpu-id"
,
"--base-gpu-id"
,
"1"
,
"4"
,
"--disaggregation-ib-device"
,
"mlx5_roce4,mlx5_roce5"
,
]
cls
.
process_decode
=
popen_launch_pd_server
(
cls
.
model
,
cls
.
decode_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
other_args
=
decode_args
,
)
def
test_gsm8k
(
self
):
args
=
SimpleNamespace
(
num_shots
=
5
,
data_path
=
None
,
num_questions
=
200
,
max_new_tokens
=
512
,
parallel
=
128
,
host
=
f
"http://
{
self
.
base_host
}
"
,
port
=
int
(
self
.
lb_port
),
)
metrics
=
run_eval_few_shot_gsm8k
(
args
)
print
(
f
"Evaluation metrics:
{
metrics
}
"
)
self
.
assertGreater
(
metrics
[
"accuracy"
],
0.60
)
class
TestDisaggregationMooncakeMHADecodeLargerTP
(
TestDisaggregationBase
):
@
classmethod
def
setUpClass
(
cls
):
super
().
setUpClass
()
# Temporarily disable JIT DeepGEMM
cls
.
original_jit_deepgemm
=
os
.
environ
.
get
(
"SGL_ENABLE_JIT_DEEPGEMM"
)
os
.
environ
[
"SGL_ENABLE_JIT_DEEPGEMM"
]
=
"false"
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
# Non blocking start servers
cls
.
start_prefill
()
cls
.
start_decode
()
# Block until both
cls
.
wait_server_ready
(
cls
.
prefill_url
+
"/health"
)
cls
.
wait_server_ready
(
cls
.
decode_url
+
"/health"
)
cls
.
launch_lb
()
@
classmethod
def
start_prefill
(
cls
):
prefill_args
=
[
"--trust-remote-code"
,
"--disaggregation-mode"
,
"prefill"
,
"--tp"
,
"2"
,
"--disaggregation-ib-device"
,
"mlx5_roce0,mlx5_roce1"
,
]
cls
.
process_prefill
=
popen_launch_pd_server
(
cls
.
model
,
cls
.
prefill_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
other_args
=
prefill_args
,
)
@
classmethod
def
start_decode
(
cls
):
decode_args
=
[
"--trust-remote-code"
,
"--disaggregation-mode"
,
"decode"
,
"--tp"
,
"4"
,
"--base-gpu-id"
,
"4"
,
"--disaggregation-ib-device"
,
"--disaggregation-ib-device"
,
"mlx5_roce
1
,mlx5_roce
2
"
,
"mlx5_roce
4
,mlx5_roce
5
"
,
]
]
cls
.
process_decode
=
popen_launch_pd_server
(
cls
.
process_decode
=
popen_launch_pd_server
(
cls
.
model
,
cls
.
model
,
...
...
test/srt/test_disaggregation_dp_attention.py
View file @
c560410d
...
@@ -17,21 +17,12 @@ from sglang.test.test_utils import (
...
@@ -17,21 +17,12 @@ from sglang.test.test_utils import (
class
TestDisaggregationDPAttention
(
TestDisaggregationBase
):
class
TestDisaggregationDPAttention
(
TestDisaggregationBase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
super
().
setUpClass
()
# Temporarily disable JIT DeepGEMM
# Temporarily disable JIT DeepGEMM
cls
.
original_jit_deepgemm
=
os
.
environ
.
get
(
"SGL_ENABLE_JIT_DEEPGEMM"
)
cls
.
original_jit_deepgemm
=
os
.
environ
.
get
(
"SGL_ENABLE_JIT_DEEPGEMM"
)
os
.
environ
[
"SGL_ENABLE_JIT_DEEPGEMM"
]
=
"false"
os
.
environ
[
"SGL_ENABLE_JIT_DEEPGEMM"
]
=
"false"
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST_MLA
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST_MLA
parsed_url
=
urlparse
(
DEFAULT_URL_FOR_TEST
)
cls
.
base_host
=
parsed_url
.
hostname
base_port
=
str
(
parsed_url
.
port
)
cls
.
lb_port
=
base_port
cls
.
prefill_port
=
f
"
{
int
(
base_port
)
+
100
}
"
cls
.
decode_port
=
f
"
{
int
(
base_port
)
+
200
}
"
cls
.
prefill_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
prefill_port
}
"
cls
.
decode_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
decode_port
}
"
cls
.
lb_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
lb_port
}
"
print
(
f
"
{
cls
.
base_host
=
}
{
cls
.
lb_port
=
}
{
cls
.
prefill_port
=
}
{
cls
.
decode_port
=
}
"
)
# Non blocking start servers
# Non blocking start servers
cls
.
start_prefill
()
cls
.
start_prefill
()
...
...
test/srt/test_disaggregation_pp.py
View file @
c560410d
import
time
import
time
import
unittest
import
unittest
from
types
import
SimpleNamespace
from
types
import
SimpleNamespace
from
urllib.parse
import
urlparse
from
sglang.test.few_shot_gsm8k
import
run_eval
from
sglang.test.few_shot_gsm8k
import
run_eval
from
sglang.test.test_disaggregation_utils
import
TestDisaggregationBase
from
sglang.test.test_disaggregation_utils
import
TestDisaggregationBase
from
sglang.test.test_utils
import
(
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
popen_launch_pd_server
,
popen_launch_pd_server
,
)
)
...
@@ -16,17 +14,8 @@ from sglang.test.test_utils import (
...
@@ -16,17 +14,8 @@ from sglang.test.test_utils import (
class
TestDisaggregationPPAccuracy
(
TestDisaggregationBase
):
class
TestDisaggregationPPAccuracy
(
TestDisaggregationBase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
super
().
setUpClass
()
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
parsed_url
=
urlparse
(
DEFAULT_URL_FOR_TEST
)
cls
.
base_host
=
parsed_url
.
hostname
base_port
=
str
(
parsed_url
.
port
)
cls
.
lb_port
=
base_port
cls
.
prefill_port
=
f
"
{
int
(
base_port
)
+
100
}
"
cls
.
decode_port
=
f
"
{
int
(
base_port
)
+
200
}
"
cls
.
prefill_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
prefill_port
}
"
cls
.
decode_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
decode_port
}
"
cls
.
lb_url
=
f
"http://
{
cls
.
base_host
}
:
{
cls
.
lb_port
}
"
print
(
f
"
{
cls
.
base_host
=
}
{
cls
.
lb_port
=
}
{
cls
.
prefill_port
=
}
{
cls
.
decode_port
=
}
"
)
# Non blocking start servers
# Non blocking start servers
cls
.
start_prefill
()
cls
.
start_prefill
()
...
@@ -45,7 +34,7 @@ class TestDisaggregationPPAccuracy(TestDisaggregationBase):
...
@@ -45,7 +34,7 @@ class TestDisaggregationPPAccuracy(TestDisaggregationBase):
"--disaggregation-mode"
,
"--disaggregation-mode"
,
"prefill"
,
"prefill"
,
"--tp-size"
,
"--tp-size"
,
"
1
"
,
"
2
"
,
"--pp-size"
,
"--pp-size"
,
"2"
,
"2"
,
"--disaggregation-ib-device"
,
"--disaggregation-ib-device"
,
...
@@ -66,11 +55,11 @@ class TestDisaggregationPPAccuracy(TestDisaggregationBase):
...
@@ -66,11 +55,11 @@ class TestDisaggregationPPAccuracy(TestDisaggregationBase):
"--disaggregation-mode"
,
"--disaggregation-mode"
,
"decode"
,
"decode"
,
"--tp"
,
"--tp"
,
"1"
,
"--base-gpu-id"
,
"2"
,
"2"
,
"--base-gpu-id"
,
"4"
,
"--disaggregation-ib-device"
,
"--disaggregation-ib-device"
,
"mlx5_roce
2
"
,
"mlx5_roce
4,mlx5_roce5
"
,
]
]
cls
.
process_decode
=
popen_launch_pd_server
(
cls
.
process_decode
=
popen_launch_pd_server
(
cls
.
model
,
cls
.
model
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment