1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

import argparse
5
import signal
6
import time
7
8
9

import uvloop

10
import vllm
11
import vllm.envs as envs
12
from vllm.entrypoints.cli.types import CLISubcommand
13
14
15
16
17
18
from vllm.entrypoints.openai.api_server import (
    run_server,
    run_server_worker,
    setup_server,
)
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
19
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
20
21
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
22
from vllm.utils.argparse_utils import FlexibleArgumentParser
23
from vllm.utils.network_utils import get_tcp_uri
24
from vllm.utils.system_utils import decorate_logs, set_process_title
25
from vllm.v1.engine.core import EngineCoreProc
26
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
27
from vllm.v1.executor import Executor
28
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
29
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
30
from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure
31
32

logger = init_logger(__name__)

# Long-form description shown by `vllm serve --help`.
DESCRIPTION = """Launch a local OpenAI-compatible API server to serve LLM
completions via HTTP. Defaults to Qwen/Qwen3-0.6B if no model is specified.

Search by using: `--help=<ConfigGroup>` to explore options by section (e.g.,
--help=ModelConfig, --help=Frontend)
  Use `--help=all` to show all available flags at once.
"""


class ServeSubcommand(CLISubcommand):
    """The `serve` subcommand for the vLLM CLI."""

    name = "serve"

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Run `vllm serve` with the parsed CLI arguments.

        Dispatches to exactly one launch mode:
          * a gRPC server, when ``--grpc`` is set;
          * headless engines (no API server), when the resolved
            ``api_server_count`` is < 1;
          * multiple API server processes, when it is > 1;
          * a single API server in this process otherwise.

        Raises:
            ValueError: propagated from ``_resolve_api_server_count`` for
                conflicting headless / data-parallel load-balancing flags.
        """
        # A model given as the positional `model_tag` takes precedence over
        # `--model`.
        model_tag = getattr(args, "model_tag", None)
        if model_tag is not None:
            args.model = model_tag

        if getattr(args, "grpc", False):
            # Imported lazily: the gRPC server needs the optional
            # `vllm[grpc]` extra (see the `--grpc` help text).
            from vllm.entrypoints.grpc_server import serve_grpc

            uvloop.run(serve_grpc(args))
            return

        ServeSubcommand._resolve_api_server_count(args)

        if args.api_server_count < 1:
            run_headless(args)
        elif args.api_server_count > 1:
            run_multi_api_server(args)
        else:
            # Single API server (this process).
            args.api_server_count = None
            uvloop.run(run_server(args))

    @staticmethod
    def _resolve_api_server_count(args: argparse.Namespace) -> None:
        """Validate headless/DP-LB flags and fill in ``args.api_server_count``.

        An explicitly provided count is kept as-is (except that ``--headless``
        forces it to 0). Otherwise it is defaulted by load-balancing mode:
          * external LB (``--data-parallel-external-lb`` or
            ``--data-parallel-rank``): 1, the external LB distributes load;
          * hybrid LB (``--data-parallel-hybrid-lb`` or
            ``--data-parallel-start-rank``): ``data_parallel_size_local``
            (or 1 when unset), an internal LB over the local ranks only;
          * internal LB: the full ``data_parallel_size``.

        Raises:
            ValueError: if a positive ``--api-server-count`` is combined with
                ``--headless``, or if external and hybrid LB are both enabled.
        """
        if args.headless:
            if args.api_server_count is not None and args.api_server_count > 0:
                raise ValueError(
                    f"--api-server-count={args.api_server_count} cannot be "
                    "used with --headless (no API servers are started in "
                    "headless mode)."
                )
            # Default to 0 in headless mode (no API servers)
            args.api_server_count = 0

        is_external_lb = (
            args.data_parallel_external_lb or args.data_parallel_rank is not None
        )
        is_hybrid_lb = (
            args.data_parallel_hybrid_lb or args.data_parallel_start_rank is not None
        )

        if is_external_lb and is_hybrid_lb:
            raise ValueError(
                "Cannot use both external and hybrid data parallel load "
                "balancing modes. External LB is enabled via "
                "--data-parallel-external-lb or --data-parallel-rank. "
                "Hybrid LB is enabled via --data-parallel-hybrid-lb or "
                "--data-parallel-start-rank. Use one mode or the other."
            )

        if args.api_server_count is not None:
            # Explicitly set (or forced to 0 by --headless): keep it.
            return

        if is_external_lb:
            args.api_server_count = 1
        elif is_hybrid_lb:
            args.api_server_count = args.data_parallel_size_local or 1
            if args.api_server_count > 1:
                logger.info(
                    "Defaulting api_server_count to data_parallel_size_local "
                    "(%d) for hybrid LB mode.",
                    args.api_server_count,
                )
        else:
            args.api_server_count = args.data_parallel_size
            if args.api_server_count > 1:
                logger.info(
                    "Defaulting api_server_count to data_parallel_size (%d).",
                    args.api_server_count,
                )

    def validate(self, args: argparse.Namespace) -> None:
        """Validate the parsed `serve` arguments via the OpenAI CLI checks."""
        validate_parsed_serve_args(args)

    def subparser_init(
        self, subparsers: argparse._SubParsersAction
    ) -> FlexibleArgumentParser:
        """Register the `serve` subparser and return it.

        Adds the OpenAI server argument set (``make_arg_parser``) plus the
        ``--grpc`` switch, and sets the shared subcommand epilog.
        """
        serve_parser = subparsers.add_parser(
            self.name,
            help="Launch a local OpenAI-compatible API server to serve LLM "
            "completions via HTTP.",
            description=DESCRIPTION,
            usage="vllm serve [model_tag] [options]",
        )

        serve_parser = make_arg_parser(serve_parser)
        serve_parser.add_argument(
            "--grpc",
            action="store_true",
            default=False,
            help="Launch a gRPC server instead of the HTTP OpenAI-compatible "
            "server. Requires: pip install vllm[grpc].",
        )
        serve_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=self.name)
        return serve_parser

def cmd_init() -> list[CLISubcommand]:
    """Return the CLI subcommands provided by this module (`serve`)."""
    return [ServeSubcommand()]

def run_headless(args: argparse.Namespace) -> None:
    """Run engine processes without any API server in this process.

    Two cases, chosen by ``parallel_config.node_rank_within_dp``:
      * > 0: run the headless multiproc-executor workers for multi-node
        PP/TP, connecting to ``master_addr:master_port``;
      * otherwise: start ``data_parallel_size_local`` engine core processes
        that hand-shake with ``data_parallel_master_ip:data_parallel_rpc_port``,
        block in ``engine_manager.join_first()``, then shut them down.

    SIGTERM/SIGINT trigger a graceful shutdown bounded by
    ``vllm_config.shutdown_timeout``.

    Raises:
        ValueError: if ``api_server_count`` > 1, if hybrid DP load balancing
            is enabled, or if ``data_parallel_size_local`` is not positive.
    """
    if args.api_server_count > 1:
        raise ValueError("api_server_count can't be set in headless mode")

    # Create the EngineConfig.
    engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(
        usage_context=usage_context, headless=True
    )

    if engine_args.data_parallel_hybrid_lb:
        raise ValueError("data_parallel_hybrid_lb is not applicable in headless mode")

    parallel_config = vllm_config.parallel_config
    local_engine_count = parallel_config.data_parallel_size_local

    if local_engine_count <= 0:
        raise ValueError("data_parallel_size_local must be > 0 in headless mode")

    shutdown_requested = False

    # Catch SIGTERM and SIGINT to allow graceful shutdown. Only the first
    # signal raises SystemExit; later ones are ignored so they do not
    # interrupt the bounded shutdown in the `finally` block below.
    def signal_handler(signum, frame):
        nonlocal shutdown_requested
        logger.debug("Received %d signal.", signum)
        if not shutdown_requested:
            shutdown_requested = True
            raise SystemExit

    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    if parallel_config.node_rank_within_dp > 0:
        from vllm.version import __version__ as VLLM_VERSION

        # Run headless workers (for multi-node PP/TP).
        host = parallel_config.master_addr
        head_node_address = f"{host}:{parallel_config.master_port}"
        logger.info(
            "Launching vLLM (v%s) headless multiproc executor, "
            "with head node address %s for torch.distributed process group.",
            VLLM_VERSION,
            head_node_address,
        )

        executor = MultiprocExecutor(vllm_config, monitor_workers=False)
        executor.start_worker_monitor(inline=True)
        return

    host = parallel_config.data_parallel_master_ip
    port = parallel_config.data_parallel_rpc_port
    handshake_address = get_tcp_uri(host, port)

    logger.info(
        "Launching %d data parallel engine(s) in headless mode, "
        "with head node address %s.",
        local_engine_count,
        handshake_address,
    )

    # Create the engines.
    engine_manager = CoreEngineProcManager(
        target_fn=EngineCoreProc.run_engine_core,
        local_engine_count=local_engine_count,
        start_index=parallel_config.data_parallel_rank,
        local_start_index=0,
        vllm_config=vllm_config,
        local_client=False,
        handshake_address=handshake_address,
        executor_class=Executor.get_class(vllm_config),
        log_stats=not engine_args.disable_log_stats,
    )

    try:
        engine_manager.join_first()
    finally:
        # Bound the wait only for a signal-initiated shutdown.
        timeout = None
        if shutdown_requested:
            timeout = vllm_config.shutdown_timeout
            logger.info("Waiting up to %d seconds for processes to exit", timeout)
        engine_manager.shutdown(timeout=timeout)
        logger.info("Shutting down.")

def run_multi_api_server(args: argparse.Namespace) -> None:
    """Launch engine core processes plus ``args.api_server_count`` API server
    processes, then block until they complete or one fails.

    On SIGTERM/SIGINT, all processes share a single shutdown deadline of
    ``vllm_config.shutdown_timeout`` seconds; otherwise each ``shutdown`` is
    called with ``timeout=None``.

    Raises:
        ValueError: if ``VLLM_ALLOW_RUNTIME_LORA_UPDATING`` is set together
            with more than one API server.
    """
    assert not args.headless
    num_api_servers: int = args.api_server_count
    assert num_api_servers > 0

    # Validate before binding the listen socket or building the engine config,
    # so an unsupported combination fails fast without leaving resources behind.
    if num_api_servers > 1 and envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
        raise ValueError(
            "VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used with api_server_count > 1"
        )

    if num_api_servers > 1:
        setup_multiprocess_prometheus()

    shutdown_requested = False

    # Catch SIGTERM and SIGINT to allow graceful shutdown. Only the first
    # signal raises SystemExit; later ones are ignored so they do not
    # interrupt the bounded shutdown in the `finally` block below.
    def signal_handler(signum, frame):
        nonlocal shutdown_requested
        logger.debug("Received %d signal.", signum)
        if not shutdown_requested:
            shutdown_requested = True
            raise SystemExit

    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    listen_address, sock = setup_server(args)

    engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
    engine_args._api_process_count = num_api_servers
    engine_args._api_process_rank = -1

    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    executor_class = Executor.get_class(vllm_config)
    log_stats = not engine_args.disable_log_stats

    parallel_config = vllm_config.parallel_config
    dp_rank = parallel_config.data_parallel_rank
    assert parallel_config.local_engines_only or dp_rank == 0

    api_server_manager: APIServerProcessManager | None = None

    from vllm.v1.engine.utils import get_engine_zmq_addresses

    addresses = get_engine_zmq_addresses(vllm_config, num_api_servers)

    with launch_core_engines(
        vllm_config, executor_class, log_stats, addresses, num_api_servers
    ) as (local_engine_manager, coordinator, addresses):
        # Construct common args for the APIServerProcessManager up-front.
        api_server_manager_kwargs = dict(
            target_server_fn=run_api_server_worker_proc,
            listen_address=listen_address,
            sock=sock,
            args=args,
            num_servers=num_api_servers,
            input_addresses=addresses.inputs,
            output_addresses=addresses.outputs,
            stats_update_address=coordinator.get_stats_publish_address()
            if coordinator
            else None,
        )

        # For dp ranks > 0 in external/hybrid DP LB modes, we must delay the
        # start of the API servers until the local engine is started
        # (after the launcher context manager exits),
        # since we get the front-end stats update address from the coordinator
        # via the handshake with the local engine.
        if dp_rank == 0 or not parallel_config.local_engines_only:
            # Start API servers using the manager.
            api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)

    try:
        # Start API servers now if they weren't already started. This is
        # inside the `try` so the engines and coordinator are still shut
        # down if starting the API servers fails.
        if api_server_manager is None:
            api_server_manager_kwargs["stats_update_address"] = (
                addresses.frontend_stats_publish_address
            )
            api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)

        # Wait for API servers
        wait_for_completion_or_failure(
            api_server_manager=api_server_manager,
            engine_manager=local_engine_manager,
            coordinator=coordinator,
        )
    finally:
        timeout = shutdown_by = None
        if shutdown_requested:
            timeout = vllm_config.shutdown_timeout
            shutdown_by = time.monotonic() + timeout
            logger.info("Waiting up to %d seconds for processes to exit", timeout)

        def to_timeout(deadline: float | None) -> float | None:
            # Seconds left until `deadline`, clamped at 0; None passes through.
            if deadline is None:
                return None
            return max(deadline - time.monotonic(), 0.0)

        if api_server_manager is not None:
            api_server_manager.shutdown(timeout=to_timeout(shutdown_by))
        if local_engine_manager:
            local_engine_manager.shutdown(timeout=to_timeout(shutdown_by))
        if coordinator:
            coordinator.shutdown(timeout=to_timeout(shutdown_by))

def run_api_server_worker_proc(
    listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
    """Entrypoint for individual API server worker processes.

    Args:
        listen_address: Address forwarded to ``run_server_worker``.
        sock: Socket forwarded to ``run_server_worker`` (presumably the one
            created by ``setup_server`` in the parent process).
        args: Parsed CLI arguments for the server.
        client_config: Optional per-worker config dict. Its ``client_index``
            (default 0) names this worker; ``None`` is treated as ``{}``.
        **uvicorn_kwargs: Extra keyword arguments for ``run_server_worker``.
    """
    client_config = client_config or {}
    server_index = client_config.get("client_index", 0)

    # Set process title and add process-specific prefix to stdout and stderr.
    set_process_title("APIServer", str(server_index))
    decorate_logs()

    uvloop.run(
        run_server_worker(listen_address, sock, args, client_config, **uvicorn_kwargs)
    )