Unverified commit c560410d, authored by Shangming Cai and committed by GitHub
Browse files

Refactor and optimize mooncake CI (#11162)


Signed-off-by: Shangming Cai <csmthu@gmail.com>
parent 590f2da0
import time
from urllib.parse import urlparse
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_with_error_check,
)
......@@ -13,8 +15,17 @@ from sglang.test.test_utils import (
class TestDisaggregationBase(CustomTestCase):
@classmethod
def setUpClass(cls):
    """Derive the endpoints shared by the disaggregation test classes.

    The host and base port are parsed from ``DEFAULT_URL_FOR_TEST``. The
    load balancer uses the base port, the prefill server the base port
    plus 100, and the decode server the base port plus 200. Ports are
    stored as strings, and a URL is built for each endpoint. The three
    process attributes are initialised to ``None``.
    """
    url_parts = urlparse(DEFAULT_URL_FOR_TEST)
    cls.base_host = url_parts.hostname

    # Ports are kept as strings; the offsets are applied on the int value.
    lb_port_text = str(url_parts.port)
    cls.lb_port = lb_port_text
    cls.prefill_port = f"{int(lb_port_text) + 100}"
    cls.decode_port = f"{int(lb_port_text) + 200}"

    cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
    cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
    cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
    print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")

    # Nothing has been launched yet.
    cls.process_lb = cls.process_decode = cls.process_prefill = None
@classmethod
def launch_lb(cls):
......
......@@ -146,12 +146,12 @@ suites = {
],
"per-commit-8-gpu": [
TestFile("hicache/test_hicache_storage_mooncake_backend.py", 400),
TestFile("lora/test_lora_llama4.py", 600),
TestFile("test_disaggregation.py", 499),
TestFile("lora/test_lora_llama4.py", 400),
TestFile("test_disaggregation.py", 600),
TestFile("test_disaggregation_dp_attention.py", 155),
TestFile("test_disaggregation_different_tp.py", 155),
TestFile("test_disaggregation_pp.py", 60),
TestFile("test_full_deepseek_v3.py", 333),
TestFile("test_disaggregation_different_tp.py", 600),
TestFile("test_disaggregation_pp.py", 140),
TestFile("test_full_deepseek_v3.py", 550),
],
"per-commit-4-gpu-b200": [
# TestFile("test_gpt_oss_4gpu.py", 600),
......
......@@ -3,7 +3,6 @@ import os
import time
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
import requests
......@@ -14,7 +13,6 @@ from sglang.test.test_utils import (
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_pd_server,
)
......@@ -22,17 +20,8 @@ from sglang.test.test_utils import (
class TestDisaggregationAccuracy(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
# Non blocking start servers
cls.start_prefill()
......@@ -51,9 +40,9 @@ class TestDisaggregationAccuracy(TestDisaggregationBase):
"--disaggregation-mode",
"prefill",
"--tp",
"1",
"2",
"--disaggregation-ib-device",
"mlx5_roce0",
"mlx5_roce0,mlx5_roce1",
]
cls.process_prefill = popen_launch_pd_server(
cls.model,
......@@ -69,11 +58,11 @@ class TestDisaggregationAccuracy(TestDisaggregationBase):
"--disaggregation-mode",
"decode",
"--tp",
"1",
"2",
"--base-gpu-id",
"1",
"2",
"--disaggregation-ib-device",
"mlx5_roce1",
"mlx5_roce2,mlx5_roce3",
]
cls.process_decode = popen_launch_pd_server(
cls.model,
......@@ -154,20 +143,11 @@ class TestDisaggregationAccuracy(TestDisaggregationBase):
class TestDisaggregationMooncakeFailure(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
# set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure
os.environ["DISAGGREGATION_TEST_FAILURE_PROB"] = "0.05"
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
# Non blocking start servers
cls.start_prefill()
......@@ -191,9 +171,9 @@ class TestDisaggregationMooncakeFailure(TestDisaggregationBase):
"--disaggregation-mode",
"prefill",
"--tp",
"1",
"2",
"--disaggregation-ib-device",
"mlx5_roce0",
"mlx5_roce0,mlx5_roce1",
]
cls.process_prefill = popen_launch_pd_server(
cls.model,
......@@ -209,11 +189,11 @@ class TestDisaggregationMooncakeFailure(TestDisaggregationBase):
"--disaggregation-mode",
"decode",
"--tp",
"1",
"2",
"--base-gpu-id",
"1",
"2",
"--disaggregation-ib-device",
"mlx5_roce1",
"mlx5_roce2,mlx5_roce3",
]
cls.process_decode = popen_launch_pd_server(
cls.model,
......@@ -254,17 +234,9 @@ class TestDisaggregationMooncakeSpec(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
cls.draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
cls.spec_args = [
"--speculative-algorithm",
"EAGLE",
......@@ -348,18 +320,9 @@ class TestDisaggregationMooncakeSpec(TestDisaggregationBase):
class TestDisaggregationSimulatedRetract(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
os.environ["SGLANG_TEST_RETRACT"] = "true"
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
# Non blocking start servers
cls.start_prefill()
......@@ -383,9 +346,9 @@ class TestDisaggregationSimulatedRetract(TestDisaggregationBase):
"--disaggregation-mode",
"prefill",
"--tp",
"1",
"2",
"--disaggregation-ib-device",
"mlx5_roce0",
"mlx5_roce0,mlx5_roce1",
]
cls.process_prefill = popen_launch_pd_server(
cls.model,
......@@ -401,11 +364,11 @@ class TestDisaggregationSimulatedRetract(TestDisaggregationBase):
"--disaggregation-mode",
"decode",
"--tp",
"1",
"2",
"--base-gpu-id",
"1",
"2",
"--disaggregation-ib-device",
"mlx5_roce1",
"mlx5_roce2,mlx5_roce3",
]
cls.process_decode = popen_launch_pd_server(
cls.model,
......
......@@ -2,14 +2,13 @@ import os
import time
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_pd_server,
)
......@@ -17,21 +16,12 @@ from sglang.test.test_utils import (
class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
# Temporarily disable JIT DeepGEMM
cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM")
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
# Non blocking start servers
cls.start_prefill()
......@@ -50,7 +40,7 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase):
"--disaggregation-mode",
"prefill",
"--tp",
"2",
"4",
"--disaggregation-ib-device",
"mlx5_roce0,mlx5_roce1",
]
......@@ -68,11 +58,11 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase):
"--disaggregation-mode",
"decode",
"--tp",
"1",
"--base-gpu-id",
"2",
"--base-gpu-id",
"4",
"--disaggregation-ib-device",
"mlx5_roce2",
"mlx5_roce4,mlx5_roce5",
]
cls.process_decode = popen_launch_pd_server(
cls.model,
......@@ -100,21 +90,12 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase):
class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
# Temporarily disable JIT DeepGEMM
cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM")
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
# Non blocking start servers
cls.start_prefill()
......@@ -133,9 +114,83 @@ class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase):
"--disaggregation-mode",
"prefill",
"--tp",
"1",
"2",
"--disaggregation-ib-device",
"mlx5_roce0,mlx5_roce1",
]
cls.process_prefill = popen_launch_pd_server(
cls.model,
cls.prefill_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=prefill_args,
)
@classmethod
def start_decode(cls):
    """Launch the decode server and keep the result on ``cls.process_decode``.

    The server is started for ``cls.model`` at ``cls.decode_url`` in
    decode disaggregation mode with ``--tp 4``, ``--base-gpu-id 4`` and
    IB devices ``mlx5_roce4,mlx5_roce5``.
    """
    valued_options = {
        "--disaggregation-mode": "decode",
        "--tp": "4",
        "--base-gpu-id": "4",
        "--disaggregation-ib-device": "mlx5_roce4,mlx5_roce5",
    }
    cli_args = ["--trust-remote-code"]
    # Flatten each option into consecutive "flag", "value" entries.
    for flag, value in valued_options.items():
        cli_args.extend((flag, value))

    cls.process_decode = popen_launch_pd_server(
        cls.model,
        cls.decode_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=cli_args,
    )
def test_gsm8k(self):
    """Run the few-shot GSM8K evaluation against ``base_host:lb_port``.

    Uses 5 shots, 200 questions, 512 new tokens and a parallelism of 128.
    The test passes when the reported ``accuracy`` metric exceeds 0.60.
    """
    eval_settings = {
        "num_shots": 5,
        "data_path": None,
        "num_questions": 200,
        "max_new_tokens": 512,
        "parallel": 128,
        "host": f"http://{self.base_host}",
        # lb_port is stored as a string; the evaluator receives an int.
        "port": int(self.lb_port),
    }
    results = run_eval_few_shot_gsm8k(SimpleNamespace(**eval_settings))
    print(f"Evaluation metrics: {results}")
    self.assertGreater(results["accuracy"], 0.60)
class TestDisaggregationMooncakeMHAPrefillLargerTP(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
    """Start the prefill and decode servers, then the load balancer.

    Runs the parent ``setUpClass`` first, sets
    ``SGL_ENABLE_JIT_DEEPGEMM`` to ``"false"`` (the previous value is
    saved on ``cls.original_jit_deepgemm``), and selects
    ``DEFAULT_MODEL_NAME_FOR_TEST`` as the model. Both servers are
    started before either health endpoint is awaited.
    """
    super().setUpClass()

    # Temporarily disable JIT DeepGEMM, remembering what was set before.
    jit_env_key = "SGL_ENABLE_JIT_DEEPGEMM"
    cls.original_jit_deepgemm = os.environ.get(jit_env_key)
    os.environ[jit_env_key] = "false"

    cls.model = DEFAULT_MODEL_NAME_FOR_TEST

    # Non-blocking: kick off both servers before waiting on either one.
    for launch in (cls.start_prefill, cls.start_decode):
        launch()

    # Block until both servers answer on their health endpoint.
    for server_url in (cls.prefill_url, cls.decode_url):
        cls.wait_server_ready(server_url + "/health")

    cls.launch_lb()
@classmethod
def start_prefill(cls):
prefill_args = [
"--trust-remote-code",
"--disaggregation-mode",
"prefill",
"--tp",
"4",
"--disaggregation-ib-device",
"mlx5_roce0",
"mlx5_roce0,mlx5_roce1",
]
cls.process_prefill = popen_launch_pd_server(
cls.model,
......@@ -153,9 +208,83 @@ class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase):
"--tp",
"2",
"--base-gpu-id",
"1",
"4",
"--disaggregation-ib-device",
"mlx5_roce4,mlx5_roce5",
]
cls.process_decode = popen_launch_pd_server(
cls.model,
cls.decode_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=decode_args,
)
def test_gsm8k(self):
    """Check GSM8K few-shot accuracy against ``base_host:lb_port``.

    The evaluation runs with 5 shots over 200 questions, up to 512 new
    tokens each and 128 parallel requests. The reported ``accuracy``
    must be greater than 0.60.
    """
    target_host = f"http://{self.base_host}"
    # lb_port is stored as a string; the evaluator receives an int.
    target_port = int(self.lb_port)

    run_config = SimpleNamespace(
        num_shots=5,
        data_path=None,
        num_questions=200,
        max_new_tokens=512,
        parallel=128,
        host=target_host,
        port=target_port,
    )
    scores = run_eval_few_shot_gsm8k(run_config)
    print(f"Evaluation metrics: {scores}")
    self.assertGreater(scores["accuracy"], 0.60)
class TestDisaggregationMooncakeMHADecodeLargerTP(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
    """Prepare the environment and start every server this class needs.

    After the parent ``setUpClass`` runs, the current
    ``SGL_ENABLE_JIT_DEEPGEMM`` value is saved on
    ``cls.original_jit_deepgemm`` and the variable is set to ``"false"``.
    The prefill and decode servers are then started, both health
    endpoints are awaited, and the load balancer is launched last.
    """
    super().setUpClass()

    # Temporarily disable JIT DeepGEMM; the old value stays on the class.
    env_name = "SGL_ENABLE_JIT_DEEPGEMM"
    cls.original_jit_deepgemm = os.environ.get(env_name)
    os.environ[env_name] = "false"

    cls.model = DEFAULT_MODEL_NAME_FOR_TEST

    # Start both servers without blocking.
    cls.start_prefill()
    cls.start_decode()

    # Wait for each server's health endpoint, prefill first.
    health_suffix = "/health"
    for endpoint in (cls.prefill_url, cls.decode_url):
        cls.wait_server_ready(endpoint + health_suffix)

    cls.launch_lb()
@classmethod
def start_prefill(cls):
    """Launch the prefill server and keep the result on ``cls.process_prefill``.

    The server is started for ``cls.model`` at ``cls.prefill_url`` in
    prefill disaggregation mode with ``--tp 2`` and IB devices
    ``mlx5_roce0,mlx5_roce1``.
    """
    valued_options = {
        "--disaggregation-mode": "prefill",
        "--tp": "2",
        "--disaggregation-ib-device": "mlx5_roce0,mlx5_roce1",
    }
    cli_args = ["--trust-remote-code"]
    # Flatten each option into consecutive "flag", "value" entries.
    for flag, value in valued_options.items():
        cli_args.extend((flag, value))

    cls.process_prefill = popen_launch_pd_server(
        cls.model,
        cls.prefill_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=cli_args,
    )
@classmethod
def start_decode(cls):
decode_args = [
"--trust-remote-code",
"--disaggregation-mode",
"decode",
"--tp",
"4",
"--base-gpu-id",
"4",
"--disaggregation-ib-device",
"mlx5_roce1,mlx5_roce2",
"mlx5_roce4,mlx5_roce5",
]
cls.process_decode = popen_launch_pd_server(
cls.model,
......
......@@ -17,21 +17,12 @@ from sglang.test.test_utils import (
class TestDisaggregationDPAttention(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
# Temporarily disable JIT DeepGEMM
cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM")
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
# Non blocking start servers
cls.start_prefill()
......
import time
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_pd_server,
)
......@@ -16,17 +14,8 @@ from sglang.test.test_utils import (
class TestDisaggregationPPAccuracy(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
# Non blocking start servers
cls.start_prefill()
......@@ -45,7 +34,7 @@ class TestDisaggregationPPAccuracy(TestDisaggregationBase):
"--disaggregation-mode",
"prefill",
"--tp-size",
"1",
"2",
"--pp-size",
"2",
"--disaggregation-ib-device",
......@@ -66,11 +55,11 @@ class TestDisaggregationPPAccuracy(TestDisaggregationBase):
"--disaggregation-mode",
"decode",
"--tp",
"1",
"--base-gpu-id",
"2",
"--base-gpu-id",
"4",
"--disaggregation-ib-device",
"mlx5_roce2",
"mlx5_roce4,mlx5_roce5",
]
cls.process_decode = popen_launch_pd_server(
cls.model,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment