import json
import os
import random
import subprocess
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from types import SimpleNamespace
from typing import List, Optional
from urllib.parse import urlparse

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.runners import DEFAULT_PROMPTS
from sglang.test.test_utils import (
    DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
    DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_pd_server,
    popen_launch_server,
)


class TestPDPPAccuracy(unittest.TestCase):
    """GSM8K accuracy check for prefill/decode disaggregation with pipeline parallelism.

    ``setUpClass`` launches a prefill server, a decode server and a load
    balancer in front of them; ``test_gsm8k`` runs the few-shot GSM8K
    evaluation through the load balancer.
    """

    # Handles of the launched processes. They default to None so that
    # tearDownClass is safe to call even when setUpClass failed part-way.
    process_prefill = None
    process_decode = None
    process_lb = None

    @classmethod
    def setUpClass(cls):
        """Start the prefill server, the decode server and the load balancer."""
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
        cls.base_host = parsed_url.hostname
        base_port = str(parsed_url.port)
        cls.lb_port = base_port
        # Prefill and decode listen on fixed offsets from the base port.
        cls.prefill_port = f"{int(base_port) + 100}"
        cls.decode_port = f"{int(base_port) + 200}"
        cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
        cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
        cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")

        try:
            # Non blocking start servers
            cls.start_prefill()
            cls.start_decode()

            # Block until both
            cls.wait_server_ready(cls.prefill_url + "/health")
            cls.wait_server_ready(cls.decode_url + "/health")

            lb_command = [
                "python3",
                "-m",
                "sglang_router.launch_router",
                "--pd-disaggregation",
                "--mini-lb",  # FIXME: remove this
                "--prefill",
                cls.prefill_url,
                "--decode",
                cls.decode_url,
                "--host",
                cls.base_host,
                "--port",
                cls.lb_port,
            ]

            print("Starting load balancer:", " ".join(lb_command))
            # Inherit stdout/stderr instead of piping them: nothing reads the
            # pipes, so a chatty child could block once the pipe buffer fills.
            cls.process_lb = subprocess.Popen(lb_command)
            cls.wait_server_ready(cls.lb_url + "/health")
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises,
            # so clean up whatever was already launched before re-raising.
            cls.tearDownClass()
            raise

    @classmethod
    def start_prefill(cls):
        """Launch the prefill server (tp=2, pp=2) without waiting for readiness."""
        prefill_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp-size",
            "2",
            "--pp-size",
            "2",
            "--disaggregation-ib-device",
            "mlx5_roce0",
            "--disable-overlap-schedule",
        ]
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Launch the decode server (tp=1) without waiting for readiness."""
        # NOTE(review): prefill with tp=2 x pp=2 presumably occupies GPUs 0-3,
        # which would overlap with --base-gpu-id 1 here -- confirm the intended
        # GPU placement.
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            "--base-gpu-id",
            "1",
            "--disaggregation-ib-device",
            "mlx5_roce1",
        ]
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    @classmethod
    def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
        """Poll ``url`` once per second until it answers with HTTP 200.

        Args:
            url: Health endpoint to poll.
            timeout: Maximum number of seconds to keep polling.

        Raises:
            RuntimeError: If no 200 response is received within ``timeout``.
        """
        deadline = time.perf_counter() + timeout
        while True:
            try:
                if requests.get(url, timeout=5).status_code == 200:
                    print(f"Server {url} is ready")
                    return
            except requests.RequestException:
                # Server is not accepting connections yet; keep polling.
                pass

            if time.perf_counter() > deadline:
                raise RuntimeError(f"Server {url} failed to start in {timeout}s")
            time.sleep(1)

    @classmethod
    def tearDownClass(cls):
        """Kill the load balancer, decode and prefill process trees."""
        for attr in ("process_lb", "process_decode", "process_prefill"):
            process = getattr(cls, attr)
            if process is None:
                continue
            try:
                kill_process_tree(process.pid)
            except Exception as exc:
                # Best-effort cleanup: report and keep killing the others.
                print(f"Error killing process {process.pid}: {exc}")
            setattr(cls, attr, None)

    def test_gsm8k(self):
        """Run few-shot GSM8K through the load balancer and check accuracy."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host=f"http://{self.base_host}",
            port=int(self.lb_port),
        )
        metrics = run_eval(args)
        print(f"{metrics=}")

        self.assertGreater(metrics["accuracy"], 0.24)
        # Wait a little bit so that the memory check happens.
        time.sleep(5)


if __name__ == "__main__":
    unittest.main()