# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
5
import signal
6
7
8

import uvloop

9
import vllm
10
import vllm.envs as envs
11
from vllm.entrypoints.cli.types import CLISubcommand
12
13
14
15
16
17
from vllm.entrypoints.openai.api_server import (
    run_server,
    run_server_worker,
    setup_server,
)
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
18
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
19
20
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
Cyrus Leung's avatar
Cyrus Leung committed
21
from vllm.utils.argparse_utils import FlexibleArgumentParser
22
from vllm.utils.network_utils import get_tcp_uri
23
from vllm.utils.system_utils import decorate_logs, set_process_title
24
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
25
from vllm.v1.executor import Executor
26
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
27
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
28
from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure
29
30

# Module-level logger, created through vLLM's `init_logger` for this module.
logger = init_logger(__name__)
31

32
33
34
35
36
37
38
39
# Long-form description for the `serve` subcommand; passed as `description=`
# to `subparsers.add_parser(...)` in `ServeSubcommand.subparser_init`.
DESCRIPTION = """Launch a local OpenAI-compatible API server to serve LLM
completions via HTTP. Defaults to Qwen/Qwen3-0.6B if no model is specified.

Search by using: `--help=<ConfigGroup>` to explore options by section (e.g.,
--help=ModelConfig, --help=Frontend)
  Use `--help=all` to show all available flags at once.
"""

40
41

class ServeSubcommand(CLISubcommand):
    """The `serve` subcommand for the vLLM CLI."""

    name = "serve"

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Run the `serve` subcommand for the parsed CLI arguments.

        Resolves the effective ``api_server_count`` and then dispatches to
        one of: the gRPC server, ``run_headless``, ``run_multi_api_server``,
        or a single in-process ``run_server``.

        Args:
            args: Parsed command-line namespace. ``args.model`` and
                ``args.api_server_count`` may be mutated in place.

        Raises:
            ValueError: If ``--api-server-count`` > 0 is combined with
                ``--headless``, or if both external and hybrid data parallel
                load-balancing modes are requested.
        """
        # If model is specified in CLI (as positional arg), it takes precedence
        if hasattr(args, "model_tag") and args.model_tag is not None:
            args.model = args.model_tag

        if getattr(args, "grpc", False):
            # Imported lazily: per the --grpc help text this needs the
            # optional `vllm[grpc]` extra.
            from vllm.entrypoints.grpc_server import serve_grpc

            uvloop.run(serve_grpc(args))
            return

        if args.headless:
            if args.api_server_count is not None and args.api_server_count > 0:
                raise ValueError(
                    f"--api-server-count={args.api_server_count} cannot be "
                    "used with --headless (no API servers are started in "
                    "headless mode)."
                )
            # Default to 0 in headless mode (no API servers)
            args.api_server_count = 0

        # Detect LB mode for defaulting api_server_count.
        # External LB: --data-parallel-external-lb or --data-parallel-rank
        # Hybrid LB: --data-parallel-hybrid-lb or --data-parallel-start-rank
        is_external_lb = (
            args.data_parallel_external_lb or args.data_parallel_rank is not None
        )
        is_hybrid_lb = (
            args.data_parallel_hybrid_lb or args.data_parallel_start_rank is not None
        )

        if is_external_lb and is_hybrid_lb:
            raise ValueError(
                "Cannot use both external and hybrid data parallel load "
                "balancing modes. External LB is enabled via "
                "--data-parallel-external-lb or --data-parallel-rank. "
                "Hybrid LB is enabled via --data-parallel-hybrid-lb or "
                "--data-parallel-start-rank. Use one mode or the other."
            )

        # Default api_server_count if not explicitly set.
        # - External LB: Leave as 1 (external LB handles distribution)
        # - Hybrid LB: Use local DP size (internal LB for local ranks only)
        # - Internal LB: Use full DP size
        if args.api_server_count is None:
            if is_external_lb:
                args.api_server_count = 1
            elif is_hybrid_lb:
                args.api_server_count = args.data_parallel_size_local or 1
                if args.api_server_count > 1:
                    logger.info(
                        "Defaulting api_server_count to data_parallel_size_local "
                        "(%d) for hybrid LB mode.",
                        args.api_server_count,
                    )
            else:
                args.api_server_count = args.data_parallel_size
                if args.api_server_count > 1:
                    logger.info(
                        "Defaulting api_server_count to data_parallel_size (%d).",
                        args.api_server_count,
                    )

        if args.api_server_count < 1:
            run_headless(args)
        elif args.api_server_count > 1:
            run_multi_api_server(args)
        else:
            # Single API server (this process).
            # NOTE(review): the count is reset to None before run_server;
            # presumably run_server treats None as "not multi-server" --
            # confirm against run_server.
            args.api_server_count = None
            uvloop.run(run_server(args))

    def validate(self, args: argparse.Namespace) -> None:
        """Validate the parsed arguments via ``validate_parsed_serve_args``."""
        validate_parsed_serve_args(args)

    def subparser_init(
        self, subparsers: argparse._SubParsersAction
    ) -> FlexibleArgumentParser:
        """Register the `serve` subparser and return it.

        Args:
            subparsers: The sub-parsers action the `serve` parser is added to.

        Returns:
            The parser returned by ``make_arg_parser``, with the ``--grpc``
            flag added and the epilog set.
        """
        serve_parser = subparsers.add_parser(
            self.name,
            help="Launch a local OpenAI-compatible API server to serve LLM "
            "completions via HTTP.",
            description=DESCRIPTION,
            usage="vllm serve [model_tag] [options]",
        )

        serve_parser = make_arg_parser(serve_parser)
        serve_parser.add_argument(
            "--grpc",
            action="store_true",
            default=False,
            help="Launch a gRPC server instead of the HTTP OpenAI-compatible "
            "server. Requires: pip install vllm[grpc].",
        )
        serve_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=self.name)
        return serve_parser
143
144


145
def cmd_init() -> list[CLISubcommand]:
    """Return the CLI subcommands defined by this module (just `serve`)."""
    return [ServeSubcommand()]
147
148
149


def run_headless(args: argparse.Namespace):
    """Run engines in headless mode (no API servers in this process).

    Builds the engine config with ``headless=True``, installs SIGTERM/SIGINT
    handlers, and then either:

    * runs a ``MultiprocExecutor`` worker monitor inline when
      ``node_rank_within_dp > 0``, or
    * starts ``data_parallel_size_local`` core engines through
      ``CoreEngineProcManager`` and blocks in ``join_first()``.

    Args:
        args: Parsed command-line namespace.

    Raises:
        ValueError: If ``args.api_server_count > 1``, if
            ``data_parallel_hybrid_lb`` is set, or if
            ``data_parallel_size_local`` is not positive.
    """
    if args.api_server_count > 1:
        raise ValueError("api_server_count can't be set in headless mode")

    # Create the EngineConfig.
    engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(
        usage_context=usage_context, headless=True
    )

    if engine_args.data_parallel_hybrid_lb:
        raise ValueError("data_parallel_hybrid_lb is not applicable in headless mode")

    parallel_config = vllm_config.parallel_config
    local_engine_count = parallel_config.data_parallel_size_local

    if local_engine_count <= 0:
        raise ValueError("data_parallel_size_local must be > 0 in headless mode")

    shutdown_requested = False

    # Catch SIGTERM and SIGINT to allow graceful shutdown.
    def signal_handler(signum, frame):
        nonlocal shutdown_requested
        logger.debug("Received %d signal.", signum)
        # Only the first signal raises SystemExit; later signals are logged
        # and otherwise ignored.
        if not shutdown_requested:
            shutdown_requested = True
            raise SystemExit

    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    if parallel_config.node_rank_within_dp > 0:
        from vllm.version import __version__ as VLLM_VERSION

        # Run headless workers (for multi-node PP/TP).
        host = parallel_config.master_addr
        head_node_address = f"{host}:{parallel_config.master_port}"
        logger.info(
            "Launching vLLM (v%s) headless multiproc executor, "
            "with head node address %s for torch.distributed process group.",
            VLLM_VERSION,
            head_node_address,
        )

        executor = MultiprocExecutor(vllm_config, monitor_workers=False)
        executor.start_worker_monitor(inline=True)
        return

    host = parallel_config.data_parallel_master_ip
    port = parallel_config.data_parallel_rpc_port
    handshake_address = get_tcp_uri(host, port)

    logger.info(
        "Launching %d data parallel engine(s) in headless mode, "
        "with head node address %s.",
        local_engine_count,
        handshake_address,
    )

    # Create the engines.
    engine_manager = CoreEngineProcManager(
        local_engine_count=local_engine_count,
        start_index=vllm_config.parallel_config.data_parallel_rank,
        local_start_index=0,
        vllm_config=vllm_config,
        local_client=False,
        handshake_address=handshake_address,
        executor_class=Executor.get_class(vllm_config),
        log_stats=not engine_args.disable_log_stats,
    )

    try:
        engine_manager.join_first()
    finally:
        # Always close the engine manager, including when the signal handler
        # raised SystemExit while join_first() was blocking.
        logger.info("Shutting down.")
        engine_manager.close()
227
228
229
230


def run_multi_api_server(args: argparse.Namespace):
    """Launch the core engines plus ``args.api_server_count`` API server processes.

    Sets up the listening socket, builds the engine config, starts the core
    engines via ``launch_core_engines``, starts the API server processes via
    ``APIServerProcessManager`` (either inside or right after the launcher
    context, see below), and finally blocks in
    ``wait_for_completion_or_failure``.

    Args:
        args: Parsed command-line namespace; must not be headless and must
            have ``api_server_count > 0``.

    Raises:
        ValueError: If ``VLLM_ALLOW_RUNTIME_LORA_UPDATING`` is enabled while
            more than one API server is requested.
    """
    assert not args.headless
    num_api_servers: int = args.api_server_count
    assert num_api_servers > 0

    if num_api_servers > 1:
        setup_multiprocess_prometheus()

    listen_address, sock = setup_server(args)

    engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
    engine_args._api_process_count = num_api_servers
    # NOTE(review): -1 looks like a sentinel for "launcher, not an individual
    # API server process" -- confirm against AsyncEngineArgs.
    engine_args._api_process_rank = -1

    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    if num_api_servers > 1 and envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
        raise ValueError(
            "VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used with api_server_count > 1"
        )

    executor_class = Executor.get_class(vllm_config)
    log_stats = not engine_args.disable_log_stats

    parallel_config = vllm_config.parallel_config
    dp_rank = parallel_config.data_parallel_rank
    assert parallel_config.local_engines_only or dp_rank == 0

    api_server_manager: APIServerProcessManager | None = None

    from vllm.v1.engine.utils import get_engine_zmq_addresses

    addresses = get_engine_zmq_addresses(vllm_config, num_api_servers)

    with launch_core_engines(
        vllm_config, executor_class, log_stats, addresses, num_api_servers
    ) as (local_engine_manager, coordinator, addresses):
        # Construct common args for the APIServerProcessManager up-front.
        api_server_manager_kwargs = dict(
            target_server_fn=run_api_server_worker_proc,
            listen_address=listen_address,
            sock=sock,
            args=args,
            num_servers=num_api_servers,
            input_addresses=addresses.inputs,
            output_addresses=addresses.outputs,
            stats_update_address=coordinator.get_stats_publish_address()
            if coordinator
            else None,
        )

        # For dp ranks > 0 in external/hybrid DP LB modes, we must delay the
        # start of the API servers until the local engine is started
        # (after the launcher context manager exits),
        # since we get the front-end stats update address from the coordinator
        # via the handshake with the local engine.
        if dp_rank == 0 or not parallel_config.local_engines_only:
            # Start API servers using the manager.
            api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)

    # Start API servers now if they weren't already started.
    if api_server_manager is None:
        api_server_manager_kwargs["stats_update_address"] = (
            addresses.frontend_stats_publish_address
        )
        api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)

    # Wait for API servers
    wait_for_completion_or_failure(
        api_server_manager=api_server_manager,
        engine_manager=local_engine_manager,
        coordinator=coordinator,
    )
303
304


305
306
307
def run_api_server_worker_proc(
    listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
    """Entrypoint for individual API server worker processes.

    Args:
        listen_address: Forwarded unchanged to ``run_server_worker``.
        sock: Forwarded unchanged to ``run_server_worker``.
        args: Parsed command-line namespace, forwarded unchanged.
        client_config: Optional mapping; its ``"client_index"`` entry
            (default 0) is used as the server index in the process title.
            ``None`` is treated as an empty dict.
        **uvicorn_kwargs: Extra keyword arguments forwarded to
            ``run_server_worker``.
    """
    client_config = client_config or {}
    server_index = client_config.get("client_index", 0)

    # Set process title and add process-specific prefix to stdout and stderr.
    set_process_title("APIServer", str(server_index))
    decorate_logs()

    uvloop.run(
        run_server_worker(listen_address, sock, args, client_config, **uvicorn_kwargs)
    )