launch_server.py 5.35 KB
Newer Older
1
2
3
4
import argparse
import copy
import multiprocessing as mp
import os
Byron Hsu's avatar
Byron Hsu committed
5
import random
6
7
8
9
10
11
12
13
14
import signal
import sys
import time
from typing import List

import requests
from sglang_router.launch_router import RouterArgs, launch_router

from sglang.srt.server import launch_server
Byron Hsu's avatar
Byron Hsu committed
15
from sglang.srt.server_args import ServerArgs
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from sglang.srt.utils import is_port_available
from sglang.utils import get_exception_traceback


# Create new process group
def run_server(server_args, dp_rank):
    os.setpgrp()  # Create new process group

    # Set DP_RANK environment variable
    os.environ["DP_RANK"] = str(dp_rank)

    launch_server(server_args)


def launch_server_process(
    server_args: ServerArgs, worker_port: int, dp_id: int
) -> mp.Process:
    """Launch a single server process with the given args and port."""
    server_args = copy.deepcopy(server_args)
    server_args.port = worker_port
    server_args.base_gpu_id = dp_id * server_args.tp_size
    server_args.dp_size = 1

    proc = mp.Process(target=run_server, args=(server_args, dp_id))
    proc.start()
    return proc


def cleanup_processes(processes: List[mp.Process]):
    """Clean up all processes using process groups."""
    print("\nCleaning up processes...")
    for proc in processes:
        if proc.is_alive():
            try:
                # Kill the entire process group
                os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
                # Give processes some time to terminate gracefully
                proc.join(timeout=3)
                # If process is still alive, force kill
                if proc.is_alive():
                    os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
            except ProcessLookupError:
                pass  # Process already terminated


def setup_signal_handlers(cleanup_func):
    """Setup handlers for various termination signals."""

    def signal_handler(signum, frame):
        cleanup_func()
        sys.exit(1)

    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)
    if hasattr(signal, "SIGQUIT"):
        signal.signal(signal.SIGQUIT, signal_handler)


def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
    """Wait for server to be healthy by checking /health endpoint."""
    start_time = time.time()
    url = f"http://{host}:{port}/health"

    while time.time() - start_time < timeout:
        try:
            response = requests.get(url, timeout=5)
            if response.status_code == 200:
                return True
        except requests.exceptions.RequestException:
            pass
        time.sleep(1)
    return False


def find_available_ports(base_port: int, count: int) -> List[int]:
    """Find consecutive available ports starting from base_port."""
    available_ports = []
    current_port = base_port

    while len(available_ports) < count:
        if is_port_available(current_port):
            available_ports.append(current_port)
Byron Hsu's avatar
Byron Hsu committed
98
        current_port += random.randint(100, 1000)
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179

    return available_ports


def main():
    # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
    mp.set_start_method("spawn")

    parser = argparse.ArgumentParser(
        description="Launch SGLang router and server processes"
    )

    ServerArgs.add_cli_args(parser)
    RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
    parser.add_argument(
        "--router-dp-worker-base-port",
        type=int,
        default=31000,
        help="Base port number for data parallel workers",
    )

    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)

    # Find available ports for workers
    worker_ports = find_available_ports(
        args.router_dp_worker_base_port, server_args.dp_size
    )

    # Start server processes
    server_processes = []

    try:
        # Launch server processes
        for i, worker_port in enumerate(worker_ports):
            proc = launch_server_process(server_args, worker_port, i)
            server_processes.append(proc)

        # Setup cleanup handler
        setup_signal_handlers(lambda: cleanup_processes(server_processes))

        # Wait for all servers to be healthy
        all_healthy = True
        for port in worker_ports:
            if not wait_for_server_health(server_args.host, port):
                print(f"Server on port {port} failed to become healthy")
                all_healthy = False
                break

        if not all_healthy:
            print("Not all servers are healthy. Shutting down...")
            cleanup_processes(server_processes)
            sys.exit(1)

        print("All servers are healthy. Starting router...")

        # Update router args with worker URLs
        router_args.worker_urls = [
            f"http://{server_args.host}:{port}" for port in worker_ports
        ]

        # Start the router
        router = launch_router(router_args)

        if router is None:
            print("Failed to start router. Shutting down...")
            cleanup_processes(server_processes)
            sys.exit(1)

    except KeyboardInterrupt:
        print("\nReceived shutdown signal...")
    except Exception as e:
        print(f"Error occurred: {e}")
        print(get_exception_traceback())
    finally:
        cleanup_processes(server_processes)


if __name__ == "__main__":
    main()