serve.py 12.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

import argparse
5
import signal
6
import time
7
8
9

import uvloop

10
import vllm
11
import vllm.envs as envs
12
from vllm.entrypoints.cli.types import CLISubcommand
13
14
15
16
17
18
from vllm.entrypoints.openai.api_server import (
    run_server,
    run_server_worker,
    setup_server,
)
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
19
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
20
21
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
Cyrus Leung's avatar
Cyrus Leung committed
22
from vllm.utils.argparse_utils import FlexibleArgumentParser
23
from vllm.utils.network_utils import get_tcp_uri
24
from vllm.utils.system_utils import decorate_logs, set_process_title
25
from vllm.v1.engine.core import EngineCoreProc
26
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
27
from vllm.v1.executor import Executor
28
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
29
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
30
from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure
31
32

# Module-level logger, named after this module.
logger = init_logger(__name__)
33

34
35
36
37
38
39
40
41
# Long-form help text for the `serve` subcommand; passed as the subparser's
# `description` in ServeSubcommand.subparser_init, so argparse prints it at
# the top of `vllm serve --help`.
DESCRIPTION = """Launch a local OpenAI-compatible API server to serve LLM
completions via HTTP. Defaults to Qwen/Qwen3-0.6B if no model is specified.

Search by using: `--help=<ConfigGroup>` to explore options by section (e.g.,
--help=ModelConfig, --help=Frontend)
  Use `--help=all` to show all available flags at once.
"""

42
43

class ServeSubcommand(CLISubcommand):
    """CLI handler for `vllm serve`: registers its parser and launches servers."""

    name = "serve"

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Resolve the API-server topology from ``args`` and start serving.

        A count below 1 runs engines only (headless), a count above 1 spawns
        several API server processes, and exactly 1 serves from this process.
        """
        # A positional model tag on the command line overrides --model.
        model_tag = getattr(args, "model_tag", None)
        if model_tag is not None:
            args.model = model_tag

        if args.headless:
            requested = args.api_server_count
            if requested is not None and requested > 0:
                raise ValueError(
                    f"--api-server-count={requested} cannot be "
                    "used with --headless (no API servers are started in "
                    "headless mode)."
                )
            # Headless mode starts no API servers at all.
            args.api_server_count = 0

        # Which data-parallel load-balancing mode was requested?
        #   external LB: --data-parallel-external-lb or --data-parallel-rank
        #   hybrid LB:   --data-parallel-hybrid-lb or --data-parallel-start-rank
        external_lb = bool(
            args.data_parallel_external_lb or args.data_parallel_rank is not None
        )
        hybrid_lb = bool(
            args.data_parallel_hybrid_lb or args.data_parallel_start_rank is not None
        )
        if external_lb and hybrid_lb:
            raise ValueError(
                "Cannot use both external and hybrid data parallel load "
                "balancing modes. External LB is enabled via "
                "--data-parallel-external-lb or --data-parallel-rank. "
                "Hybrid LB is enabled via --data-parallel-hybrid-lb or "
                "--data-parallel-start-rank. Use one mode or the other."
            )

        if args.api_server_count is None:
            args.api_server_count = ServeSubcommand._default_api_server_count(
                args, external_lb=external_lb, hybrid_lb=hybrid_lb
            )

        server_count = args.api_server_count
        if server_count < 1:
            run_headless(args)
        elif server_count > 1:
            run_multi_api_server(args)
        else:
            # Exactly one API server, running inside this very process.
            args.api_server_count = None
            uvloop.run(run_server(args))

    @staticmethod
    def _default_api_server_count(
        args: argparse.Namespace, *, external_lb: bool, hybrid_lb: bool
    ) -> int:
        """Choose an API server count when --api-server-count was not given.

        External LB -> 1 (the external balancer spreads the load);
        hybrid LB -> the local DP size (at least 1);
        internal LB -> the full DP size.
        """
        if external_lb:
            return 1

        if hybrid_lb:
            count = args.data_parallel_size_local or 1
            if count > 1:
                logger.info(
                    "Defaulting api_server_count to data_parallel_size_local "
                    "(%d) for hybrid LB mode.",
                    count,
                )
            return count

        count = args.data_parallel_size
        if count > 1:
            logger.info(
                "Defaulting api_server_count to data_parallel_size (%d).",
                count,
            )
        return count

    def validate(self, args: argparse.Namespace) -> None:
        """Delegate checking of the parsed serve arguments."""
        validate_parsed_serve_args(args)

    def subparser_init(
        self, subparsers: argparse._SubParsersAction
    ) -> FlexibleArgumentParser:
        """Register the `serve` subparser and attach every serve option to it."""
        parser = subparsers.add_parser(
            self.name,
            help="Launch a local OpenAI-compatible API server to serve LLM "
            "completions via HTTP.",
            description=DESCRIPTION,
            usage="vllm serve [model_tag] [options]",
        )
        parser = make_arg_parser(parser)
        parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=self.name)
        return parser
132
133


134
def cmd_init() -> list[CLISubcommand]:
    """Return the CLI subcommands contributed by this module (only `serve`)."""
    subcommands: list[CLISubcommand] = [ServeSubcommand()]
    return subcommands
136
137
138


def run_headless(args: argparse.Namespace):
    """Run this node's engine-side processes without starting any API server.

    Builds a headless engine config from ``args`` and then takes one of two
    paths:

    * ``parallel_config.node_rank_within_dp > 0``: this node contributes
      workers to a multi-node PP/TP group. A ``MultiprocExecutor`` is created,
      ``start_worker_monitor(inline=True)`` is called on it, and the function
      returns once that call returns.
    * otherwise: ``data_parallel_size_local`` engine-core processes are started
      through ``CoreEngineProcManager``, handshaking with
      ``data_parallel_master_ip:data_parallel_rpc_port``. The function blocks
      in ``join_first()`` and always shuts the engines down afterwards.

    The first SIGTERM/SIGINT raises ``SystemExit`` to unwind the call stack.
    In the data-parallel path, engine shutdown then receives
    ``vllm_config.shutdown_timeout`` as its timeout. Otherwise it is called
    with ``timeout=None``.

    Args:
        args: Parsed ``vllm serve`` arguments. ``api_server_count`` must be
            ``None`` or less than 1, because no API servers run in headless
            mode.

    Raises:
        ValueError: if a positive ``api_server_count`` is given, if
            ``data_parallel_hybrid_lb`` is enabled, or if
            ``data_parallel_size_local`` is not positive.
    """
    # Headless mode never starts API servers, so reject any positive count.
    # This matches the --headless check in ServeSubcommand.cmd. The previous
    # "> 1" guard let a count of 1 through and raised TypeError on None.
    if args.api_server_count is not None and args.api_server_count > 0:
        raise ValueError("api_server_count can't be set in headless mode")

    # Create the EngineConfig.
    engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(
        usage_context=usage_context, headless=True
    )

    if engine_args.data_parallel_hybrid_lb:
        raise ValueError("data_parallel_hybrid_lb is not applicable in headless mode")

    parallel_config = vllm_config.parallel_config
    local_engine_count = parallel_config.data_parallel_size_local

    if local_engine_count <= 0:
        raise ValueError("data_parallel_size_local must be > 0 in headless mode")

    shutdown_requested = False

    # Catch SIGTERM and SIGINT to allow graceful shutdown. The first signal
    # unwinds via SystemExit. Later ones are only logged, so a shutdown that
    # is already in progress is not interrupted.
    def signal_handler(signum, frame):
        nonlocal shutdown_requested
        logger.debug("Received %d signal.", signum)
        if not shutdown_requested:
            shutdown_requested = True
            raise SystemExit

    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    if parallel_config.node_rank_within_dp > 0:
        from vllm.version import __version__ as VLLM_VERSION

        # Run headless workers (for multi-node PP/TP).
        host = parallel_config.master_addr
        head_node_address = f"{host}:{parallel_config.master_port}"
        logger.info(
            "Launching vLLM (v%s) headless multiproc executor, "
            "with head node address %s for torch.distributed process group.",
            VLLM_VERSION,
            head_node_address,
        )

        # NOTE(review): the executor is not explicitly shut down on this path.
        # Confirm that MultiprocExecutor cleans up its workers when the
        # inline monitor exits.
        executor = MultiprocExecutor(vllm_config, monitor_workers=False)
        executor.start_worker_monitor(inline=True)
        return

    host = parallel_config.data_parallel_master_ip
    port = parallel_config.data_parallel_rpc_port
    handshake_address = get_tcp_uri(host, port)

    logger.info(
        "Launching %d data parallel engine(s) in headless mode, "
        "with head node address %s.",
        local_engine_count,
        handshake_address,
    )

    # Create the engines.
    engine_manager = CoreEngineProcManager(
        target_fn=EngineCoreProc.run_engine_core,
        local_engine_count=local_engine_count,
        start_index=parallel_config.data_parallel_rank,
        local_start_index=0,
        vllm_config=vllm_config,
        local_client=False,
        handshake_address=handshake_address,
        executor_class=Executor.get_class(vllm_config),
        log_stats=not engine_args.disable_log_stats,
    )

    try:
        engine_manager.join_first()
    finally:
        # Only a signal-initiated shutdown is given a timeout.
        timeout = None
        if shutdown_requested:
            timeout = vllm_config.shutdown_timeout
            logger.info("Waiting up to %d seconds for processes to exit", timeout)
        engine_manager.shutdown(timeout=timeout)
        logger.info("Shutting down.")
221
222
223
224


def run_multi_api_server(args: argparse.Namespace):
    """Launch the engine core(s) together with one or more API server processes.

    Steps:

    1. Build the engine config from ``args``.
    2. Start the local engines (and the DP coordinator, if any) via
       ``launch_core_engines``.
    3. Spawn the API servers through ``APIServerProcessManager``.
    4. Block in ``wait_for_completion_or_failure``.

    On the way out, every component is shut down. After a SIGTERM/SIGINT,
    the API servers get ``vllm_config.shutdown_timeout`` seconds. The engines
    and the coordinator get whatever remains of that same deadline.
    """
    assert not args.headless
    server_count: int = args.api_server_count
    assert server_count > 0

    if server_count > 1:
        setup_multiprocess_prometheus()

    stop_requested = False

    # The first SIGTERM/SIGINT unwinds via SystemExit so the finally-block
    # below can shut everything down. Later signals are only logged.
    def _on_signal(signum, frame):
        nonlocal stop_requested
        logger.debug("Received %d signal.", signum)
        if stop_requested:
            return
        stop_requested = True
        raise SystemExit

    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, _on_signal)

    listen_address, sock = setup_server(args)

    engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
    engine_args._api_process_count = server_count
    engine_args._api_process_rank = -1
    vllm_config = engine_args.create_engine_config(
        usage_context=UsageContext.OPENAI_API_SERVER
    )

    if server_count > 1 and envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
        raise ValueError(
            "VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used with api_server_count > 1"
        )

    executor_class = Executor.get_class(vllm_config)
    log_stats = not engine_args.disable_log_stats

    parallel_config = vllm_config.parallel_config
    dp_rank = parallel_config.data_parallel_rank
    assert parallel_config.local_engines_only or dp_rank == 0

    api_server_manager: APIServerProcessManager | None = None

    from vllm.v1.engine.utils import get_engine_zmq_addresses

    addresses = get_engine_zmq_addresses(vllm_config, server_count)

    with launch_core_engines(
        vllm_config, executor_class, log_stats, addresses, server_count
    ) as (local_engine_manager, coordinator, addresses):
        # Arguments shared by both the early and the deferred server start.
        server_kwargs = {
            "target_server_fn": run_api_server_worker_proc,
            "listen_address": listen_address,
            "sock": sock,
            "args": args,
            "num_servers": server_count,
            "input_addresses": addresses.inputs,
            "output_addresses": addresses.outputs,
            "stats_update_address": (
                None if not coordinator else coordinator.get_stats_publish_address()
            ),
        }

        # Ranks > 0 in external/hybrid LB mode only learn the front-end stats
        # address from the handshake with their local engine, which completes
        # when the launcher context exits. Their API servers therefore start
        # below, after the with-block. Everyone else starts them right away.
        if dp_rank == 0 or not parallel_config.local_engines_only:
            api_server_manager = APIServerProcessManager(**server_kwargs)

    if api_server_manager is None:
        server_kwargs["stats_update_address"] = (
            addresses.frontend_stats_publish_address
        )
        api_server_manager = APIServerProcessManager(**server_kwargs)

    try:
        wait_for_completion_or_failure(
            api_server_manager=api_server_manager,
            engine_manager=local_engine_manager,
            coordinator=coordinator,
        )
    finally:
        timeout = None
        deadline = None
        if stop_requested:
            timeout = vllm_config.shutdown_timeout
            deadline = time.monotonic() + timeout
            logger.info("Waiting up to %d seconds for processes to exit", timeout)

        def remaining() -> float | None:
            # Seconds left before the shared deadline, or None when unbounded.
            if deadline is None:
                return None
            return max(deadline - time.monotonic(), 0.0)

        api_server_manager.shutdown(timeout=timeout)
        if local_engine_manager:
            local_engine_manager.shutdown(timeout=remaining())
        if coordinator:
            coordinator.shutdown(timeout=remaining())
328
329


330
331
332
def run_api_server_worker_proc(
    listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
    """Entrypoint executed inside each spawned API server process.

    The server index is read from ``client_config["client_index"]`` (default
    0). It is used to tag the process title and log output. The function then
    runs ``run_server_worker`` to completion on a uvloop event loop.
    """
    config = client_config if client_config else {}
    index = config.get("client_index", 0)

    # Set the process title and add a process-specific prefix to stdout/stderr.
    set_process_title("APIServer", str(index))
    decorate_logs()

    worker = run_server_worker(listen_address, sock, args, config, **uvicorn_kwargs)
    uvloop.run(worker)