Unverified commit c560410d, authored by Shangming Cai, committed by GitHub
Browse files

Refactor and optimize mooncake CI (#11162)


Signed-off-by: Shangming Cai <csmthu@gmail.com>
parent 590f2da0
import time import time
from urllib.parse import urlparse
import requests import requests
from sglang.srt.utils import kill_process_tree from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase, CustomTestCase,
popen_with_error_check, popen_with_error_check,
) )
...@@ -13,8 +15,17 @@ from sglang.test.test_utils import ( ...@@ -13,8 +15,17 @@ from sglang.test.test_utils import (
class TestDisaggregationBase(CustomTestCase): class TestDisaggregationBase(CustomTestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
cls.process_lb, cls.process_decode, cls.process_prefill = None, None, None cls.process_lb, cls.process_decode, cls.process_prefill = None, None, None
pass
@classmethod @classmethod
def launch_lb(cls): def launch_lb(cls):
......
...@@ -146,12 +146,12 @@ suites = { ...@@ -146,12 +146,12 @@ suites = {
], ],
"per-commit-8-gpu": [ "per-commit-8-gpu": [
TestFile("hicache/test_hicache_storage_mooncake_backend.py", 400), TestFile("hicache/test_hicache_storage_mooncake_backend.py", 400),
TestFile("lora/test_lora_llama4.py", 600), TestFile("lora/test_lora_llama4.py", 400),
TestFile("test_disaggregation.py", 499), TestFile("test_disaggregation.py", 600),
TestFile("test_disaggregation_dp_attention.py", 155), TestFile("test_disaggregation_dp_attention.py", 155),
TestFile("test_disaggregation_different_tp.py", 155), TestFile("test_disaggregation_different_tp.py", 600),
TestFile("test_disaggregation_pp.py", 60), TestFile("test_disaggregation_pp.py", 140),
TestFile("test_full_deepseek_v3.py", 333), TestFile("test_full_deepseek_v3.py", 550),
], ],
"per-commit-4-gpu-b200": [ "per-commit-4-gpu-b200": [
# TestFile("test_gpt_oss_4gpu.py", 600), # TestFile("test_gpt_oss_4gpu.py", 600),
......
...@@ -3,7 +3,6 @@ import os ...@@ -3,7 +3,6 @@ import os
import time import time
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from urllib.parse import urlparse
import requests import requests
...@@ -14,7 +13,6 @@ from sglang.test.test_utils import ( ...@@ -14,7 +13,6 @@ from sglang.test.test_utils import (
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_pd_server, popen_launch_pd_server,
) )
...@@ -22,17 +20,8 @@ from sglang.test.test_utils import ( ...@@ -22,17 +20,8 @@ from sglang.test.test_utils import (
class TestDisaggregationAccuracy(TestDisaggregationBase): class TestDisaggregationAccuracy(TestDisaggregationBase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super().setUpClass()
cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.model = DEFAULT_MODEL_NAME_FOR_TEST
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
# Non blocking start servers # Non blocking start servers
cls.start_prefill() cls.start_prefill()
...@@ -51,9 +40,9 @@ class TestDisaggregationAccuracy(TestDisaggregationBase): ...@@ -51,9 +40,9 @@ class TestDisaggregationAccuracy(TestDisaggregationBase):
"--disaggregation-mode", "--disaggregation-mode",
"prefill", "prefill",
"--tp", "--tp",
"1", "2",
"--disaggregation-ib-device", "--disaggregation-ib-device",
"mlx5_roce0", "mlx5_roce0,mlx5_roce1",
] ]
cls.process_prefill = popen_launch_pd_server( cls.process_prefill = popen_launch_pd_server(
cls.model, cls.model,
...@@ -69,11 +58,11 @@ class TestDisaggregationAccuracy(TestDisaggregationBase): ...@@ -69,11 +58,11 @@ class TestDisaggregationAccuracy(TestDisaggregationBase):
"--disaggregation-mode", "--disaggregation-mode",
"decode", "decode",
"--tp", "--tp",
"1", "2",
"--base-gpu-id", "--base-gpu-id",
"1", "2",
"--disaggregation-ib-device", "--disaggregation-ib-device",
"mlx5_roce1", "mlx5_roce2,mlx5_roce3",
] ]
cls.process_decode = popen_launch_pd_server( cls.process_decode = popen_launch_pd_server(
cls.model, cls.model,
...@@ -154,20 +143,11 @@ class TestDisaggregationAccuracy(TestDisaggregationBase): ...@@ -154,20 +143,11 @@ class TestDisaggregationAccuracy(TestDisaggregationBase):
class TestDisaggregationMooncakeFailure(TestDisaggregationBase): class TestDisaggregationMooncakeFailure(TestDisaggregationBase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super().setUpClass()
# set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure # set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure
os.environ["DISAGGREGATION_TEST_FAILURE_PROB"] = "0.05" os.environ["DISAGGREGATION_TEST_FAILURE_PROB"] = "0.05"
cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.model = DEFAULT_MODEL_NAME_FOR_TEST
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
# Non blocking start servers # Non blocking start servers
cls.start_prefill() cls.start_prefill()
...@@ -191,9 +171,9 @@ class TestDisaggregationMooncakeFailure(TestDisaggregationBase): ...@@ -191,9 +171,9 @@ class TestDisaggregationMooncakeFailure(TestDisaggregationBase):
"--disaggregation-mode", "--disaggregation-mode",
"prefill", "prefill",
"--tp", "--tp",
"1", "2",
"--disaggregation-ib-device", "--disaggregation-ib-device",
"mlx5_roce0", "mlx5_roce0,mlx5_roce1",
] ]
cls.process_prefill = popen_launch_pd_server( cls.process_prefill = popen_launch_pd_server(
cls.model, cls.model,
...@@ -209,11 +189,11 @@ class TestDisaggregationMooncakeFailure(TestDisaggregationBase): ...@@ -209,11 +189,11 @@ class TestDisaggregationMooncakeFailure(TestDisaggregationBase):
"--disaggregation-mode", "--disaggregation-mode",
"decode", "decode",
"--tp", "--tp",
"1", "2",
"--base-gpu-id", "--base-gpu-id",
"1", "2",
"--disaggregation-ib-device", "--disaggregation-ib-device",
"mlx5_roce1", "mlx5_roce2,mlx5_roce3",
] ]
cls.process_decode = popen_launch_pd_server( cls.process_decode = popen_launch_pd_server(
cls.model, cls.model,
...@@ -254,17 +234,9 @@ class TestDisaggregationMooncakeSpec(TestDisaggregationBase): ...@@ -254,17 +234,9 @@ class TestDisaggregationMooncakeSpec(TestDisaggregationBase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super().setUpClass()
cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
cls.draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST cls.draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
cls.spec_args = [ cls.spec_args = [
"--speculative-algorithm", "--speculative-algorithm",
"EAGLE", "EAGLE",
...@@ -348,18 +320,9 @@ class TestDisaggregationMooncakeSpec(TestDisaggregationBase): ...@@ -348,18 +320,9 @@ class TestDisaggregationMooncakeSpec(TestDisaggregationBase):
class TestDisaggregationSimulatedRetract(TestDisaggregationBase): class TestDisaggregationSimulatedRetract(TestDisaggregationBase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super().setUpClass()
os.environ["SGLANG_TEST_RETRACT"] = "true" os.environ["SGLANG_TEST_RETRACT"] = "true"
cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.model = DEFAULT_MODEL_NAME_FOR_TEST
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
# Non blocking start servers # Non blocking start servers
cls.start_prefill() cls.start_prefill()
...@@ -383,9 +346,9 @@ class TestDisaggregationSimulatedRetract(TestDisaggregationBase): ...@@ -383,9 +346,9 @@ class TestDisaggregationSimulatedRetract(TestDisaggregationBase):
"--disaggregation-mode", "--disaggregation-mode",
"prefill", "prefill",
"--tp", "--tp",
"1", "2",
"--disaggregation-ib-device", "--disaggregation-ib-device",
"mlx5_roce0", "mlx5_roce0,mlx5_roce1",
] ]
cls.process_prefill = popen_launch_pd_server( cls.process_prefill = popen_launch_pd_server(
cls.model, cls.model,
...@@ -401,11 +364,11 @@ class TestDisaggregationSimulatedRetract(TestDisaggregationBase): ...@@ -401,11 +364,11 @@ class TestDisaggregationSimulatedRetract(TestDisaggregationBase):
"--disaggregation-mode", "--disaggregation-mode",
"decode", "decode",
"--tp", "--tp",
"1", "2",
"--base-gpu-id", "--base-gpu-id",
"1", "2",
"--disaggregation-ib-device", "--disaggregation-ib-device",
"mlx5_roce1", "mlx5_roce2,mlx5_roce3",
] ]
cls.process_decode = popen_launch_pd_server( cls.process_decode = popen_launch_pd_server(
cls.model, cls.model,
......
...@@ -2,14 +2,13 @@ import os ...@@ -2,14 +2,13 @@ import os
import time import time
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from urllib.parse import urlparse
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_disaggregation_utils import TestDisaggregationBase from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST_MLA, DEFAULT_MODEL_NAME_FOR_TEST_MLA,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_pd_server, popen_launch_pd_server,
) )
...@@ -17,21 +16,12 @@ from sglang.test.test_utils import ( ...@@ -17,21 +16,12 @@ from sglang.test.test_utils import (
class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase): class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super().setUpClass()
# Temporarily disable JIT DeepGEMM # Temporarily disable JIT DeepGEMM
cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM") cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM")
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
# Non blocking start servers # Non blocking start servers
cls.start_prefill() cls.start_prefill()
...@@ -50,7 +40,7 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase): ...@@ -50,7 +40,7 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase):
"--disaggregation-mode", "--disaggregation-mode",
"prefill", "prefill",
"--tp", "--tp",
"2", "4",
"--disaggregation-ib-device", "--disaggregation-ib-device",
"mlx5_roce0,mlx5_roce1", "mlx5_roce0,mlx5_roce1",
] ]
...@@ -68,11 +58,11 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase): ...@@ -68,11 +58,11 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase):
"--disaggregation-mode", "--disaggregation-mode",
"decode", "decode",
"--tp", "--tp",
"1",
"--base-gpu-id",
"2", "2",
"--base-gpu-id",
"4",
"--disaggregation-ib-device", "--disaggregation-ib-device",
"mlx5_roce2", "mlx5_roce4,mlx5_roce5",
] ]
cls.process_decode = popen_launch_pd_server( cls.process_decode = popen_launch_pd_server(
cls.model, cls.model,
...@@ -100,21 +90,12 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase): ...@@ -100,21 +90,12 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase):
class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase): class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super().setUpClass()
# Temporarily disable JIT DeepGEMM # Temporarily disable JIT DeepGEMM
cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM") cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM")
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
# Non blocking start servers # Non blocking start servers
cls.start_prefill() cls.start_prefill()
...@@ -133,9 +114,83 @@ class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase): ...@@ -133,9 +114,83 @@ class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase):
"--disaggregation-mode", "--disaggregation-mode",
"prefill", "prefill",
"--tp", "--tp",
"1", "2",
"--disaggregation-ib-device",
"mlx5_roce0,mlx5_roce1",
]
cls.process_prefill = popen_launch_pd_server(
cls.model,
cls.prefill_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=prefill_args,
)
@classmethod
def start_decode(cls):
    """Launch the decode-mode server and keep its handle on the class.

    The server is started for ``cls.model`` at ``cls.decode_url`` with
    tensor parallelism 4, base GPU id 4 and the two IB devices listed
    below. The resulting process handle is stored in ``cls.process_decode``.
    """
    cls.process_decode = popen_launch_pd_server(
        cls.model,
        cls.decode_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=[
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "4",
            "--base-gpu-id",
            "4",
            "--disaggregation-ib-device",
            "mlx5_roce4,mlx5_roce5",
        ],
    )
def test_gsm8k(self):
    """Run the few-shot GSM8K evaluation through the load balancer port.

    Asserts that the reported ``accuracy`` metric is greater than 0.60.
    """
    eval_settings = {
        "num_shots": 5,
        "data_path": None,
        "num_questions": 200,
        "max_new_tokens": 512,
        "parallel": 128,
        "host": f"http://{self.base_host}",
        # lb_port is kept as a string on the class; the evaluator gets an int.
        "port": int(self.lb_port),
    }
    metrics = run_eval_few_shot_gsm8k(SimpleNamespace(**eval_settings))
    print(f"Evaluation metrics: {metrics}")

    accuracy = metrics["accuracy"]
    self.assertGreater(accuracy, 0.60)
class TestDisaggregationMooncakeMHAPrefillLargerTP(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
    """Start the prefill server, the decode server and the load balancer.

    Runs the base-class setup first, then launches both servers, waits for
    each one's ``/health`` endpoint, and finally launches the load balancer.
    """
    super().setUpClass()
    cls.model = DEFAULT_MODEL_NAME_FOR_TEST

    # Temporarily disable JIT DeepGEMM; the previous value is remembered on
    # the class in ``original_jit_deepgemm``.
    cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM")
    os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"

    # Start both servers before waiting on either of them.
    cls.start_prefill()
    cls.start_decode()

    # Block until both servers answer on their health endpoints.
    for server_url in (cls.prefill_url, cls.decode_url):
        cls.wait_server_ready(server_url + "/health")

    cls.launch_lb()
@classmethod
def start_prefill(cls):
prefill_args = [
"--trust-remote-code",
"--disaggregation-mode",
"prefill",
"--tp",
"4",
"--disaggregation-ib-device", "--disaggregation-ib-device",
"mlx5_roce0", "mlx5_roce0,mlx5_roce1",
] ]
cls.process_prefill = popen_launch_pd_server( cls.process_prefill = popen_launch_pd_server(
cls.model, cls.model,
...@@ -153,9 +208,83 @@ class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase): ...@@ -153,9 +208,83 @@ class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase):
"--tp", "--tp",
"2", "2",
"--base-gpu-id", "--base-gpu-id",
"1", "4",
"--disaggregation-ib-device",
"mlx5_roce4,mlx5_roce5",
]
cls.process_decode = popen_launch_pd_server(
cls.model,
cls.decode_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=decode_args,
)
def test_gsm8k(self):
    """Evaluate few-shot GSM8K against the load balancer endpoint.

    The test passes when the returned ``accuracy`` metric exceeds 0.60.
    """
    target_host = f"http://{self.base_host}"
    target_port = int(self.lb_port)  # stored as a string on the class

    eval_args = SimpleNamespace(
        num_shots=5,
        data_path=None,
        num_questions=200,
        max_new_tokens=512,
        parallel=128,
        host=target_host,
        port=target_port,
    )
    metrics = run_eval_few_shot_gsm8k(eval_args)
    print(f"Evaluation metrics: {metrics}")

    self.assertGreater(metrics["accuracy"], 0.60)
class TestDisaggregationMooncakeMHADecodeLargerTP(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
    """Prepare the class fixture: two servers plus the load balancer.

    After the base-class setup, JIT DeepGEMM is switched off through the
    environment, the prefill and decode servers are launched, both health
    endpoints are awaited, and the load balancer is started last.
    """
    super().setUpClass()

    # Temporarily disable JIT DeepGEMM, keeping the prior setting on the class.
    jit_flag = "SGL_ENABLE_JIT_DEEPGEMM"
    cls.original_jit_deepgemm = os.environ.get(jit_flag)
    os.environ[jit_flag] = "false"

    cls.model = DEFAULT_MODEL_NAME_FOR_TEST

    # Launch both servers first, then wait for each of them in turn.
    cls.start_prefill()
    cls.start_decode()
    health_suffix = "/health"
    cls.wait_server_ready(cls.prefill_url + health_suffix)
    cls.wait_server_ready(cls.decode_url + health_suffix)

    cls.launch_lb()
@classmethod
def start_prefill(cls):
    """Launch the prefill-mode server and keep its handle on the class.

    The server is started for ``cls.model`` at ``cls.prefill_url`` with
    tensor parallelism 2 and the two IB devices listed below. The process
    handle is stored in ``cls.process_prefill``.
    """
    cls.process_prefill = popen_launch_pd_server(
        cls.model,
        cls.prefill_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=[
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "2",
            "--disaggregation-ib-device",
            "mlx5_roce0,mlx5_roce1",
        ],
    )
@classmethod
def start_decode(cls):
decode_args = [
"--trust-remote-code",
"--disaggregation-mode",
"decode",
"--tp",
"4",
"--base-gpu-id",
"4",
"--disaggregation-ib-device", "--disaggregation-ib-device",
"mlx5_roce1,mlx5_roce2", "mlx5_roce4,mlx5_roce5",
] ]
cls.process_decode = popen_launch_pd_server( cls.process_decode = popen_launch_pd_server(
cls.model, cls.model,
......
...@@ -17,21 +17,12 @@ from sglang.test.test_utils import ( ...@@ -17,21 +17,12 @@ from sglang.test.test_utils import (
class TestDisaggregationDPAttention(TestDisaggregationBase): class TestDisaggregationDPAttention(TestDisaggregationBase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super().setUpClass()
# Temporarily disable JIT DeepGEMM # Temporarily disable JIT DeepGEMM
cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM") cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM")
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
# Non blocking start servers # Non blocking start servers
cls.start_prefill() cls.start_prefill()
......
import time import time
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from urllib.parse import urlparse
from sglang.test.few_shot_gsm8k import run_eval from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_disaggregation_utils import TestDisaggregationBase from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_pd_server, popen_launch_pd_server,
) )
...@@ -16,17 +14,8 @@ from sglang.test.test_utils import ( ...@@ -16,17 +14,8 @@ from sglang.test.test_utils import (
class TestDisaggregationPPAccuracy(TestDisaggregationBase): class TestDisaggregationPPAccuracy(TestDisaggregationBase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super().setUpClass()
cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.model = DEFAULT_MODEL_NAME_FOR_TEST
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
# Non blocking start servers # Non blocking start servers
cls.start_prefill() cls.start_prefill()
...@@ -45,7 +34,7 @@ class TestDisaggregationPPAccuracy(TestDisaggregationBase): ...@@ -45,7 +34,7 @@ class TestDisaggregationPPAccuracy(TestDisaggregationBase):
"--disaggregation-mode", "--disaggregation-mode",
"prefill", "prefill",
"--tp-size", "--tp-size",
"1", "2",
"--pp-size", "--pp-size",
"2", "2",
"--disaggregation-ib-device", "--disaggregation-ib-device",
...@@ -66,11 +55,11 @@ class TestDisaggregationPPAccuracy(TestDisaggregationBase): ...@@ -66,11 +55,11 @@ class TestDisaggregationPPAccuracy(TestDisaggregationBase):
"--disaggregation-mode", "--disaggregation-mode",
"decode", "decode",
"--tp", "--tp",
"1",
"--base-gpu-id",
"2", "2",
"--base-gpu-id",
"4",
"--disaggregation-ib-device", "--disaggregation-ib-device",
"mlx5_roce2", "mlx5_roce4,mlx5_roce5",
] ]
cls.process_decode = popen_launch_pd_server( cls.process_decode = popen_launch_pd_server(
cls.model, cls.model,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment