# launch_server.py
import argparse
import copy
import logging
import multiprocessing as mp
import os
import random
import signal
import sys
import time
from typing import List

import requests
from sglang_router.launch_router import RouterArgs, launch_router

from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import is_port_available
from sglang.utils import get_exception_traceback


def setup_logger():
    """Return the ``"router"`` logger, configured for console output.

    The logger level is set to INFO and a ``StreamHandler`` with the
    ``[Router (Python)]`` prefix is attached. The handler is tagged with a
    name so that repeated calls reuse it instead of stacking duplicates
    (each extra handler would print every record one more time).

    Returns:
        The ``logging.Logger`` named ``"router"``.
    """
    logger = logging.getLogger("router")
    logger.setLevel(logging.INFO)

    # Match our own handler by name rather than "any handler present", so a
    # handler some other component attached to "router" does not stop us
    # from installing the prefixed console output.
    handler_name = "router-python-stream"
    if any(h.get_name() == handler_name for h in logger.handlers):
        return logger

    formatter = logging.Formatter(
        "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    handler = logging.StreamHandler()
    handler.set_name(handler_name)
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    return logger


def run_server(server_args, dp_rank):
    """Worker-process entry point: run a single SGLang server replica.

    Args:
        server_args: Server configuration for this replica (already adjusted
            by the caller for its port / GPU offset).
        dp_rank: Data-parallel rank of this replica, exported to the server
            through the ``SGLANG_DP_RANK`` environment variable.
    """
    # Create new process group, so the launcher can signal this server and
    # the processes it starts (children inherit the group unless they change
    # it) as one unit via os.killpg in cleanup_processes.
    os.setpgrp()

    # Set SGLANG_DP_RANK environment variable
    os.environ["SGLANG_DP_RANK"] = str(dp_rank)

    launch_server(server_args)


def launch_server_process(
    server_args: ServerArgs, worker_port: int, dp_id: int
) -> mp.Process:
    """Start one data-parallel server replica in a child process.

    The caller's ``server_args`` is not modified. A deep copy is made with:
      * ``port`` set to ``worker_port``;
      * ``base_gpu_id`` set to ``dp_id * tp_size`` (presumably so each
        replica uses its own range of GPUs);
      * ``dp_size`` forced to 1 for the single replica the child runs.

    Returns:
        The already-started ``mp.Process`` whose target is ``run_server``.
    """
    worker_args = copy.deepcopy(server_args)
    worker_args.port = worker_port
    worker_args.base_gpu_id = worker_args.tp_size * dp_id
    worker_args.dp_size = 1

    child = mp.Process(target=run_server, args=(worker_args, dp_id))
    child.start()
    return child


def cleanup_processes(processes: List[mp.Process]):
    """Stop every worker process that is still alive.

    Each live process is sent SIGTERM (to its process group) and given 3
    seconds to exit. If it is still alive after that, the group is sent
    SIGKILL and the process is joined so it does not linger as a zombie.
    Processes that have already exited are skipped.

    Args:
        processes: Worker processes, e.g. those from ``launch_server_process``.
    """
    logger = logging.getLogger("router")
    logger.info("Cleaning up processes...")

    for proc in processes:
        if not proc.is_alive():
            continue
        try:
            _signal_process_group(proc, signal.SIGTERM)
            proc.join(timeout=3)
            if proc.is_alive():
                logger.warning(
                    f"Process {proc.pid} did not terminate gracefully, force killing..."
                )
                _signal_process_group(proc, signal.SIGKILL)
                # Reap the killed child so it is not left behind as a zombie.
                proc.join(timeout=3)
        except ProcessLookupError:
            # The process exited between the liveness check and the signal.
            pass


def _signal_process_group(proc: mp.Process, sig: int) -> None:
    """Send ``sig`` to ``proc``'s process group, never to the caller's own group.

    Workers call ``os.setpgrp()`` when they start, but a cleanup triggered
    right after ``proc.start()`` can run before that happens. At that point
    the child still shares this process's group, and ``os.killpg`` would
    signal the launcher itself, so only the child is signalled instead.
    """
    pgid = os.getpgid(proc.pid)
    if pgid == os.getpgrp():
        os.kill(proc.pid, sig)
    else:
        os.killpg(pgid, sig)


def setup_signal_handlers(cleanup_func):
    """Route termination signals to ``cleanup_func`` and then exit with status 1.

    One shared handler is registered for SIGTERM and SIGINT, and also for
    SIGQUIT on platforms that define it.
    """

    def _on_signal(signum, frame):
        cleanup_func()
        sys.exit(1)

    watched = [signal.SIGTERM, signal.SIGINT]
    quit_signal = getattr(signal, "SIGQUIT", None)
    if quit_signal is not None:
        watched.append(quit_signal)

    for sig in watched:
        signal.signal(sig, _on_signal)


def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
    """Poll ``http://{host}:{port}/health`` until it answers with HTTP 200.

    Each probe uses a 5-second request timeout. Failed probes (request
    errors or a non-200 status) are retried after a 1-second pause.

    Args:
        host: Hostname or IP address of the server.
        port: Port the server listens on.
        timeout: Overall time budget in seconds.

    Returns:
        True as soon as a probe returns 200, or False once ``timeout``
        seconds have elapsed without one.
    """
    url = f"http://{host}:{port}/health"
    # A monotonic clock keeps wall-clock adjustments (NTP, manual changes)
    # from shortening or extending the wait.
    deadline = time.monotonic() + timeout

    while time.monotonic() < deadline:
        try:
            response = requests.get(url, timeout=5)
            if response.status_code == 200:
                return True
        except requests.exceptions.RequestException:
            pass  # Not reachable yet; keep polling.
        time.sleep(1)
    return False


def find_available_ports(base_port: int, count: int) -> List[int]:
    """Return ``count`` free ports, probing upward from ``base_port``.

    The ports are NOT consecutive. After every probe, whether it succeeded
    or not, the candidate advances by a random step of 100-1000. Presumably
    this leaves room for auxiliary ports each server opens near its main
    port; confirm before changing the stride.

    Args:
        base_port: First port to probe.
        count: Number of ports to return; ``count <= 0`` returns ``[]``.

    Returns:
        Free ports in ascending order.

    Raises:
        RuntimeError: If the probe passes 65535 before ``count`` free ports
            are found. Without this bound the search would never end.
    """
    max_port = 65535
    available_ports = []
    current_port = base_port

    while len(available_ports) < count:
        if current_port > max_port:
            raise RuntimeError(
                f"Could not find {count} available ports starting from "
                f"{base_port}: found {len(available_ports)} before exceeding "
                f"port {max_port}"
            )
        if is_port_available(current_port):
            available_ports.append(current_port)
        current_port += random.randint(100, 1000)

    return available_ports


def main():
    """Launch ``dp_size`` SGLang server replicas and a router in front of them.

    Steps:
      1. Parse the combined server/router command-line arguments.
      2. Pick one port per data-parallel worker and start a server on each.
      3. Poll every worker's ``/health`` endpoint; abort if any never becomes
         healthy.
      4. Point the router at the worker URLs and start it.

    Worker processes are cleaned up on every way out of step 2-4: normal
    return, error, KeyboardInterrupt, and SIGTERM/SIGINT/SIGQUIT (through the
    installed signal handlers). Exits with status 1 on failure.
    """
    logger = setup_logger()

    # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
    mp.set_start_method("spawn")

    parser = argparse.ArgumentParser(
        description="Launch SGLang router and server processes"
    )

    ServerArgs.add_cli_args(parser)
    RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
    parser.add_argument(
        "--router-dp-worker-base-port",
        type=int,
        default=31000,
        help="Base port number for data parallel workers",
    )

    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)

    # Find available ports for workers
    worker_ports = find_available_ports(
        args.router_dp_worker_base_port, server_args.dp_size
    )

    server_processes = []

    # Install the cleanup handler BEFORE launching anything. The lambda reads
    # ``server_processes`` when a signal arrives, so workers appended below are
    # still covered, and a SIGTERM during the launch loop no longer leaves the
    # already-started workers running.
    setup_signal_handlers(lambda: cleanup_processes(server_processes))

    failed = False
    try:
        for i, worker_port in enumerate(worker_ports):
            logger.info(f"Launching DP server process {i} on port {worker_port}")
            proc = launch_server_process(server_args, worker_port, i)
            server_processes.append(proc)

        # Wait for all servers to be healthy
        for port in worker_ports:
            if not wait_for_server_health(server_args.host, port):
                logger.error(f"Server on port {port} failed to become healthy")
                logger.error("Not all servers are healthy. Shutting down...")
                # SystemExit is not an Exception subclass, so it skips the
                # handler below; ``finally`` still performs the cleanup.
                sys.exit(1)

        logger.info("All servers are healthy. Starting router...")

        # Update router args with worker URLs
        router_args.worker_urls = [
            f"http://{server_args.host}:{port}" for port in worker_ports
        ]

        # Start the router
        router = launch_router(router_args)

        if router is None:
            logger.error("Failed to start router. Shutting down...")
            sys.exit(1)

    except KeyboardInterrupt:
        logger.info("Received shutdown signal...")
    except Exception as e:
        logger.error(f"Error occurred: {e}")
        logger.error(get_exception_traceback())
        failed = True
    finally:
        # Runs on every path out of the try, including the sys.exit(1) calls
        # above. cleanup_processes logs its own "Cleaning up processes..."
        # line and skips processes that have already exited.
        cleanup_processes(server_processes)

    if failed:
        # An unexpected error must not be reported to the caller as success.
        sys.exit(1)


if __name__ == "__main__":
    # Guarded so that worker processes started with the "spawn" method, which
    # re-import the main module under a different __name__, do not re-run
    # the launcher.
    main()