# launch_server.py
import argparse
import copy
3
import logging
4
5
import multiprocessing as mp
import os
Byron Hsu's avatar
Byron Hsu committed
6
import random
7
8
9
10
11
12
import signal
import sys
import time
from typing import List

import requests
13
from setproctitle import setproctitle
14
15
16
from sglang_router.launch_router import RouterArgs, launch_router

from sglang.srt.server import launch_server
Byron Hsu's avatar
Byron Hsu committed
17
from sglang.srt.server_args import ServerArgs
18
19
20
from sglang.srt.utils import is_port_available


21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def setup_logger():
    """Return the ``"router"`` logger configured for INFO-level stream output.

    The logger gets a ``StreamHandler`` whose records are prefixed with
    ``[Router (Python)]`` and a ``YYYY-MM-DD HH:MM:SS`` timestamp.

    The handler is attached only when the logger has no handlers yet, so
    calling this function more than once does not cause every record to be
    emitted multiple times.

    Returns:
        The ``logging.Logger`` registered under the name ``"router"``.
    """
    logger = logging.getLogger("router")
    logger.setLevel(logging.INFO)

    # logging.getLogger returns the same object for the same name, so an
    # unconditional addHandler would stack a new handler on every call.
    if not logger.handlers:
        formatter = logging.Formatter(
            "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )

        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


37
38
39
# Module-level logger used by the functions in this file.
logger = setup_logger()


40
41
def run_server(server_args, dp_rank):
    """Run one SGLang server; used as the target of a worker process.

    Sets the process title, exports the data-parallel rank through the
    ``SGLANG_DP_RANK`` environment variable, and then hands control to
    ``launch_server``.

    Args:
        server_args: Server configuration, passed unchanged to
            ``launch_server``.
        dp_rank: Data-parallel rank of this worker; stored in the
            environment as a string.
    """
    # NOTE(review): the comment that used to precede this function read
    # "Create new process group", but nothing in the body changes the process
    # group -- confirm whether an os.setpgrp() call was intended here.
    setproctitle("sglang::server")

    # Set SGLANG_DP_RANK environment variable
    os.environ["SGLANG_DP_RANK"] = str(dp_rank)

    launch_server(server_args)


def launch_server_process(
    server_args: ServerArgs, worker_port: int, dp_id: int
) -> mp.Process:
    """Start one worker server in its own process and return the handle.

    The caller's ``server_args`` is left untouched: a deep copy is made and
    specialised for this worker before it is handed to the child.

    Args:
        server_args: Template configuration shared by all workers.
        worker_port: Port assigned to this worker.
        dp_id: Index of this worker; also passed to ``run_server`` as the
            data-parallel rank.

    Returns:
        The already-started ``mp.Process`` running ``run_server``.
    """
    worker_args = copy.deepcopy(server_args)

    # Per-worker overrides: own port, own GPU offset (dp_id * tp_size), and a
    # dp_size of 1 on the copy handed to the child.
    worker_args.port = worker_port
    worker_args.base_gpu_id = dp_id * worker_args.tp_size
    worker_args.dp_size = 1

    worker = mp.Process(target=run_server, args=(worker_args, dp_id))
    worker.start()
    return worker


def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
    """Poll ``http://{host}:{port}/health`` until it answers with HTTP 200.

    The endpoint is queried roughly once per second. Request errors
    (``requests.exceptions.RequestException``) and non-200 responses are
    treated as "not ready yet" and retried.

    Args:
        host: Host name or address of the server.
        port: Port of the server.
        timeout: Maximum number of seconds to keep polling. A value of zero
            or less returns ``False`` without sending any request.

    Returns:
        ``True`` as soon as a 200 response is received, ``False`` if the
        timeout elapses first.
    """
    # Use the monotonic clock so that wall-clock adjustments (NTP, manual
    # changes) can neither shorten nor extend the waiting period.
    deadline = time.monotonic() + timeout
    url = f"http://{host}:{port}/health"

    while time.monotonic() < deadline:
        try:
            response = requests.get(url, timeout=5)
        except requests.exceptions.RequestException:
            # Server not reachable yet; retry after the pause below.
            pass
        else:
            if response.status_code == 200:
                return True
        time.sleep(1)
    return False


def find_available_ports(base_port: int, count: int) -> List[int]:
    """Find ``count`` available ports, probing upward from ``base_port``.

    After every probe the candidate port advances by a random stride between
    100 and 1000, so the returned ports are strictly increasing but NOT
    consecutive.
    NOTE(review): the stride presumably leaves room for additional ports each
    server may open -- confirm.

    Args:
        base_port: First port to probe.
        count: Number of ports required. Zero or less yields an empty list.

    Returns:
        A list of ``count`` port numbers for which ``is_port_available``
        returned true at probe time.

    Raises:
        RuntimeError: If the search runs past port 65535 before ``count``
            ports have been found.
    """
    max_port = 65535  # highest valid TCP port number
    available_ports = []
    current_port = base_port

    while len(available_ports) < count:
        # Without this bound the loop would keep probing invalid port numbers
        # forever once the valid range is exhausted.
        if current_port > max_port:
            raise RuntimeError(
                f"Could not find {count} available ports starting from "
                f"{base_port}; found only {len(available_ports)}"
            )
        if is_port_available(current_port):
            available_ports.append(current_port)
        current_port += random.randint(100, 1000)

    return available_ports


92
93
94
95
96
def cleanup_processes(processes: List[mp.Process]):
    """Call ``terminate()`` on each process in ``processes``, in list order.

    The function does not wait for the processes to exit.
    """
    for proc in processes:
        proc.terminate()


97
def main():
    """Launch one server process per data-parallel rank, then the router.

    Steps:
      1. Parse the combined server / router command line.
      2. Pick one port per data-parallel worker, starting the search at
         ``--router-dp-worker-base-port``.
      3. Start a server process on each port.
      4. Install SIGINT / SIGTERM / SIGQUIT handlers that terminate the
         server processes.
      5. Wait for every server's ``/health`` endpoint. If any server fails to
         become healthy, terminate all server processes and exit with
         status 1 instead of starting a router in front of dead workers.
      6. Start the router with the workers' URLs.
    """
    # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
    mp.set_start_method("spawn")

    parser = argparse.ArgumentParser(
        description="Launch SGLang router and server processes"
    )

    ServerArgs.add_cli_args(parser)
    RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
    parser.add_argument(
        "--router-dp-worker-base-port",
        type=int,
        default=31000,
        help="Base port number for data parallel workers",
    )

    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)

    # Find available ports for workers
    worker_ports = find_available_ports(
        args.router_dp_worker_base_port, server_args.dp_size
    )

    # Start server processes
    server_processes = []

    for i, worker_port in enumerate(worker_ports):
        logger.info(f"Launching DP server process {i} on port {worker_port}")
        proc = launch_server_process(server_args, worker_port, i)
        server_processes.append(proc)

    # Forward termination signals to the workers so they are not left behind.
    def _terminate_workers(sig, frame):
        cleanup_processes(server_processes)

    for signum in (signal.SIGINT, signal.SIGTERM, signal.SIGQUIT):
        signal.signal(signum, _terminate_workers)

    for port in worker_ports:
        if not wait_for_server_health(server_args.host, port):
            logger.error(f"Server on port {port} failed to become healthy")
            # Do not fall through to the router start-up: stop every worker
            # and report the failure through the exit status.
            cleanup_processes(server_processes)
            sys.exit(1)

    logger.info("All servers are healthy. Starting router...")

    # Update router args with worker URLs
    router_args.worker_urls = [
        f"http://{server_args.host}:{port}" for port in worker_ports
    ]

    # Start the router
    launch_router(router_args)
154
155
156
157


if __name__ == "__main__":
    main()