"""launch_server.py: start SGLang data-parallel server workers and an SGLang router in front of them."""
1
import argparse
2
import asyncio
3
import copy
4
import logging
5
6
import multiprocessing as mp
import os
Byron Hsu's avatar
Byron Hsu committed
7
import random
8
9
10
11
12
13
import signal
import sys
import time
from typing import List

import requests
14
from setproctitle import setproctitle
15
16
from sglang_router.launch_router import RouterArgs, launch_router

Byron Hsu's avatar
Byron Hsu committed
17
from sglang.srt.server_args import ServerArgs
18
19
20
from sglang.srt.utils import is_port_available


21
22
23
24
25
def setup_logger():
    """Return the "router" logger, configured for INFO-level console output.

    The first call attaches a ``StreamHandler`` whose records are prefixed
    with ``[Router (Python)]`` and end with ``filename:lineno``. Later calls
    return the same logger without attaching another copy of that handler.

    Returns:
        The ``logging.Logger`` named "router".
    """
    logger = logging.getLogger("router")
    logger.setLevel(logging.INFO)

    # getLogger("router") returns the same object on every call, so only
    # attach our handler once; otherwise each repeated call would add another
    # handler and every message would be emitted again.
    if any(getattr(h, "_sglang_launch_server", False) for h in logger.handlers):
        return logger

    formatter = logging.Formatter(
        "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s - %(filename)s:%(lineno)d",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    handler._sglang_launch_server = True  # marks this handler as the one we own
    logger.addHandler(handler)

    return logger


logger = setup_logger()


40
41
def run_server(server_args, dp_rank):
    """Entry point of one data-parallel server process.

    Runs as the target of the child process started by
    ``launch_server_process``. It moves into its own process group, sets its
    process title, exports its data-parallel rank, and then serves requests
    over gRPC or HTTP depending on ``server_args.grpc_mode``.

    Args:
        server_args: Per-worker server arguments (the caller has already set
            the worker's port, GPU offset and ``dp_size = 1``).
        dp_rank: Index of this worker; exported as ``SGLANG_DP_RANK``.

    Note on process groups:

    1. Without os.setpgrp(), all processes share the terminal's PGID. When
       you press Ctrl+C, the terminal sends SIGINT to every process in the
       group simultaneously. Leaf processes may then terminate first, which
       breaks the cleanup order and produces orphaned processes.

       Terminal (PGID=100)
       └── Main Python Process (PGID=100)
           └── Server Process 1 (PGID=100)
               └── Scheduler 1
               └── Detokenizer 1
           └── Server Process 2 (PGID=100)
               └── Scheduler 2
               └── Detokenizer 2

    2. os.setpgrp() is called here, i.e. only in each server process. Each
       server therefore becomes the leader of a new group (PGID == its PID),
       and the processes it spawns afterwards inherit that group. The main
       process stays in the terminal's group, so the terminal's SIGINT
       reaches only the main process. The main process then tears down each
       server group itself via ``os.killpg(pid, ...)`` in
       ``cleanup_processes``.

       Terminal (PGID=100)
       └── Main Python Process (PGID=100)
           └── Server Process 1 (PGID=300)
               └── Scheduler 1
               └── Detokenizer 1
           └── Server Process 2 (PGID=400)
               └── Scheduler 2
               └── Detokenizer 2
    """
    # Create a new process group led by this server process (see docstring).
    os.setpgrp()

    setproctitle("sglang::server")

    # Set SGLANG_DP_RANK environment variable for this process (and anything
    # it spawns, which inherits the environment).
    os.environ["SGLANG_DP_RANK"] = str(dp_rank)

    # Launch server in appropriate mode (HTTP or gRPC)
    if server_args.grpc_mode:
        from sglang.srt.entrypoints.grpc_server import serve_grpc

        asyncio.run(serve_grpc(server_args))
    else:
        from sglang.srt.entrypoints.http_server import launch_server

        launch_server(server_args)


def launch_server_process(
    server_args: ServerArgs, worker_port: int, dp_id: int
) -> mp.Process:
    """Launch a single data-parallel server process with the given args and port.

    The shared ``server_args`` is deep-copied, so per-worker overrides do not
    leak into other workers. The copy is adjusted as follows:

    - listen on ``worker_port``;
    - run as a single replica (``dp_size = 1``);
    - start its GPU range at the user-supplied ``base_gpu_id`` plus
      ``dp_id * tp_size``.

    Args:
        server_args: Parsed server arguments shared by all workers.
        worker_port: Port this worker listens on.
        dp_id: Zero-based data-parallel rank of the worker.

    Returns:
        The started ``multiprocessing.Process`` running ``run_server``.
    """
    server_args = copy.deepcopy(server_args)
    server_args.port = worker_port
    # Offset from the requested base GPU instead of overwriting it, so a
    # non-zero base_gpu_id is honored. When the base is 0 (presumably the
    # default) this is the same dp_id * tp_size layout as before.
    base_gpu_id = getattr(server_args, "base_gpu_id", 0) or 0
    server_args.base_gpu_id = base_gpu_id + dp_id * server_args.tp_size
    server_args.dp_size = 1

    proc = mp.Process(target=run_server, args=(server_args, dp_id))
    proc.start()
    return proc


def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
    """Poll ``http://{host}:{port}/health`` until it answers HTTP 200.

    Each probe uses a 5-second request timeout. After a failed or non-200
    probe the function sleeps 1 second and tries again. Request errors are
    treated as "not ready yet" and are not raised.

    Returns:
        True as soon as a probe returns status 200; False once ``timeout``
        seconds have elapsed without one.
    """
    began = time.perf_counter()
    health_url = f"http://{host}:{port}/health"

    def _elapsed() -> float:
        return time.perf_counter() - began

    while _elapsed() < timeout:
        try:
            if requests.get(health_url, timeout=5).status_code == 200:
                return True
        except requests.exceptions.RequestException:
            # Server not reachable yet (refused, reset, timed out); retry.
            pass
        time.sleep(1)
    return False


def find_available_ports(base_port: int, count: int) -> List[int]:
    """Pick ``count`` free ports at or above ``base_port``.

    The ports are NOT consecutive. After every probe, whether or not the
    port was free, the candidate advances by a random step of 100-1000, so
    the selected ports are spread out. Each port is only checked with
    ``is_port_available`` at selection time.

    Args:
        base_port: First port to probe.
        count: Number of ports to return; ``count <= 0`` yields ``[]``.

    Returns:
        The selected ports, in increasing order.

    Raises:
        RuntimeError: If the search passes port 65535 before ``count`` free
            ports were found (previously this could loop without bound).
    """
    max_port = 65535  # highest valid TCP/UDP port number
    available_ports = []
    current_port = base_port

    while len(available_ports) < count:
        if current_port > max_port:
            raise RuntimeError(
                f"Only found {len(available_ports)} of {count} available ports "
                f"starting from {base_port} before exceeding port {max_port}"
            )
        if is_port_available(current_port):
            available_ports.append(current_port)
        current_port += random.randint(100, 1000)

    return available_ports


129
130
def cleanup_processes(processes: List[mp.Process], timeout: float = 5.0):
    """Terminate each server process together with its process group.

    ``run_server`` calls ``os.setpgrp()``, so each server process leads a
    group whose PGID equals its PID. Signalling that group with
    ``os.killpg`` also reaches the processes the server spawned.

    The shutdown sequence is:

    1. SIGTERM is sent to every group.
    2. All processes share a single grace period of ``timeout`` seconds.
    3. Any process whose leader is still alive afterwards gets SIGKILL on
       its group and is then joined so it is reaped.

    Only the leader's liveness is checked.

    Args:
        processes: Server processes returned by ``launch_server_process``.
        timeout: Grace period in seconds between SIGTERM and SIGKILL, shared
            across all processes rather than applied per process.
    """
    for process in processes:
        logger.info(f"Terminating process group {process.pid}")
        try:
            os.killpg(process.pid, signal.SIGTERM)
        except ProcessLookupError:
            # Process group may already be terminated
            pass

    # Wait for processes to terminate. All groups were signalled at once, so
    # share one deadline: total wait is bounded by `timeout`, not
    # `timeout * len(processes)`.
    deadline = time.monotonic() + timeout
    for process in processes:
        process.join(timeout=max(0.0, deadline - time.monotonic()))
        if process.is_alive():
            logger.warning(
                f"Process {process.pid} did not terminate gracefully, forcing kill"
            )
            try:
                os.killpg(process.pid, signal.SIGKILL)
            except ProcessLookupError:
                pass
            # Reap the killed process so it does not linger as a zombie.
            process.join(timeout=timeout)

    logger.info("All process groups terminated")
151
152


153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
def main():
    """Parse CLI args, start ``dp_size`` server workers, then run the router.

    Server options come from ``ServerArgs``. Router options come from
    ``RouterArgs`` and use the ``router-`` prefix. Each data-parallel worker
    runs in its own process on its own port, and the router is given the
    URLs of all workers.

    On SIGINT, SIGTERM or SIGQUIT, all workers are terminated and the
    process exits. If the router fails to start, the workers are cleaned up
    and the process exits with status 1.
    """
    # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
    mp.set_start_method("spawn")

    parser = argparse.ArgumentParser(
        description="Launch SGLang router and server processes"
    )

    ServerArgs.add_cli_args(parser)
    RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
    parser.add_argument(
        "--router-dp-worker-base-port",
        type=int,
        default=31000,
        help="Base port number for data parallel workers",
    )
    # No extra retry/CB flags here; RouterArgs.add_cli_args already defines them with router- prefix

    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)

    # Find available ports for workers
    worker_ports = find_available_ports(
        args.router_dp_worker_base_port, server_args.dp_size
    )

    server_processes = []

    def _shutdown(signum, frame):
        # Workers lead their own process groups (run_server calls setpgrp),
        # so the terminal's Ctrl+C does not reach them. Kill them explicitly,
        # then exit instead of returning into the router with no workers.
        logger.info(f"Received signal {signum}, shutting down")
        cleanup_processes(server_processes)
        sys.exit(0)

    # Install handlers before spawning, so a signal that arrives mid-launch
    # still cleans up the workers started so far.
    for sig in (signal.SIGINT, signal.SIGTERM, signal.SIGQUIT):
        signal.signal(sig, _shutdown)

    # Start server processes
    for i, worker_port in enumerate(worker_ports):
        logger.info(f"Launching DP server process {i} on port {worker_port}")
        proc = launch_server_process(server_args, worker_port, i)
        server_processes.append(proc)

    # Update router args with worker URLs.
    # Use grpc:// protocol if server is in gRPC mode, otherwise http://
    protocol = "grpc" if server_args.grpc_mode else "http"
    router_args.worker_urls = [
        f"{protocol}://{server_args.host}:{port}" for port in worker_ports
    ]

    # Start the router
    try:
        launch_router(router_args)
    except Exception as e:
        logger.error(f"Failed to start router: {e}")
        cleanup_processes(server_processes)
        sys.exit(1)


if __name__ == "__main__":
    main()