serve.py 11.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

import argparse
5
import signal
6
7
8

import uvloop

9
import vllm
10
import vllm.envs as envs
11
from vllm.entrypoints.cli.types import CLISubcommand
12
13
14
15
16
17
from vllm.entrypoints.openai.api_server import (
    run_server,
    run_server_worker,
    setup_server,
)
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
18
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
19
20
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
Cyrus Leung's avatar
Cyrus Leung committed
21
from vllm.utils.argparse_utils import FlexibleArgumentParser
22
from vllm.utils.network_utils import get_tcp_uri
23
from vllm.utils.system_utils import decorate_logs, set_process_title
24
from vllm.v1.engine.core import EngineCoreProc
25
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
26
from vllm.v1.executor import Executor
27
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
28
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
29
from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure
30
31

# Module-level logger for this entrypoint module.
logger = init_logger(__name__)
32

33
34
35
36
37
38
39
40
# Long-form help text for `vllm serve`; used as the subparser description.
# Kept line-for-line so the rendered --help output is unchanged.
DESCRIPTION = (
    "Launch a local OpenAI-compatible API server to serve LLM\n"
    "completions via HTTP. Defaults to Qwen/Qwen3-0.6B if no model is specified.\n"
    "\n"
    "Search by using: `--help=<ConfigGroup>` to explore options by section (e.g.,\n"
    "--help=ModelConfig, --help=Frontend)\n"
    "  Use `--help=all` to show all available flags at once.\n"
)

41
42

class ServeSubcommand(CLISubcommand):
    """The `serve` subcommand for the vLLM CLI.

    Works out how many API server processes to run and then dispatches to
    one of three launch paths:

    * fewer than 1: engines only, no API server (``run_headless``);
    * more than 1: several API server processes (``run_multi_api_server``);
    * exactly 1: a single API server inside the current process.
    """

    name = "serve"

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Entry point for ``vllm serve``.

        Mutates ``args`` in place (``model`` and ``api_server_count``) before
        dispatching to the selected launch path.

        Raises:
            ValueError: if a positive ``--api-server-count`` is combined with
                ``--headless``, or if external and hybrid data parallel load
                balancing are both requested.
        """
        # A positional model tag, when present, overrides --model.
        model_tag = getattr(args, "model_tag", None)
        if model_tag is not None:
            args.model = model_tag

        if args.headless:
            requested = args.api_server_count
            if requested is not None and requested > 0:
                raise ValueError(
                    f"--api-server-count={requested} cannot be "
                    "used with --headless (no API servers are started in "
                    "headless mode)."
                )
            # Headless mode never starts API servers.
            args.api_server_count = 0

        # Which DP load-balancing mode was requested?
        #   external: --data-parallel-external-lb or --data-parallel-rank
        #   hybrid:   --data-parallel-hybrid-lb or --data-parallel-start-rank
        external_lb = (
            args.data_parallel_external_lb or args.data_parallel_rank is not None
        )
        hybrid_lb = (
            args.data_parallel_hybrid_lb or args.data_parallel_start_rank is not None
        )
        if external_lb and hybrid_lb:
            raise ValueError(
                "Cannot use both external and hybrid data parallel load "
                "balancing modes. External LB is enabled via "
                "--data-parallel-external-lb or --data-parallel-rank. "
                "Hybrid LB is enabled via --data-parallel-hybrid-lb or "
                "--data-parallel-start-rank. Use one mode or the other."
            )

        # No explicit --api-server-count: pick one from the LB mode.
        if args.api_server_count is None:
            message: str | None
            if external_lb:
                # The external load balancer distributes requests itself.
                count, message = 1, None
            elif hybrid_lb:
                # Internal LB across the local ranks only.
                count = args.data_parallel_size_local or 1
                message = (
                    "Defaulting api_server_count to data_parallel_size_local "
                    "(%d) for hybrid LB mode."
                )
            else:
                # Internal LB across the full DP size.
                count = args.data_parallel_size
                message = "Defaulting api_server_count to data_parallel_size (%d)."
            args.api_server_count = count
            if message is not None and count > 1:
                logger.info(message, count)

        if args.api_server_count < 1:
            run_headless(args)
            return
        if args.api_server_count > 1:
            run_multi_api_server(args)
            return

        # Exactly one API server: run it in this process. The count is reset
        # to None before handing `args` to run_server.
        args.api_server_count = None
        uvloop.run(run_server(args))

    def validate(self, args: argparse.Namespace) -> None:
        """Validate parsed ``serve`` arguments via ``validate_parsed_serve_args``."""
        validate_parsed_serve_args(args)

    def subparser_init(
        self, subparsers: argparse._SubParsersAction
    ) -> FlexibleArgumentParser:
        """Register the ``serve`` subparser and attach its options and epilog."""
        summary = (
            "Launch a local OpenAI-compatible API server to serve LLM "
            "completions via HTTP."
        )
        parser = subparsers.add_parser(
            self.name,
            help=summary,
            description=DESCRIPTION,
            usage="vllm serve [model_tag] [options]",
        )
        parser = make_arg_parser(parser)
        parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=self.name)
        return parser
131
132


133
def cmd_init() -> list[CLISubcommand]:
    """Return the CLI subcommands provided by this module (only ``serve``)."""
    subcommands: list[CLISubcommand] = [ServeSubcommand()]
    return subcommands
135
136
137


def run_headless(args: argparse.Namespace):
    """Run engine process(es) on this node without starting any API server.

    Two paths, chosen by ``parallel_config.node_rank_within_dp``:

    * ``> 0``: this node only hosts workers for a multi-node PP/TP group. A
      ``MultiprocExecutor`` is created and its worker monitor is started
      with ``inline=True``; the function returns when that call returns.
    * otherwise: ``data_parallel_size_local`` engine core processes are
      launched through ``CoreEngineProcManager``, using the handshake address
      built from ``data_parallel_master_ip`` / ``data_parallel_rpc_port``.
      The call blocks in ``join_first()``, and the manager is always closed
      on the way out.

    SIGTERM and SIGINT are intercepted: the first one raises ``SystemExit``
    and later ones are only logged.

    Raises:
        ValueError: if more than one API server is requested, if hybrid DP
            load balancing is enabled, or if ``data_parallel_size_local`` is
            not positive.
    """
    # api_server_count may still be None (ServeSubcommand.cmd treats None as
    # "not specified") when this is called without going through cmd. Guard
    # the comparison so that case does not fail with a TypeError.
    api_server_count = args.api_server_count
    if api_server_count is not None and api_server_count > 1:
        raise ValueError("api_server_count can't be set in headless mode")

    # Create the EngineConfig.
    engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(
        usage_context=usage_context, headless=True
    )

    if engine_args.data_parallel_hybrid_lb:
        raise ValueError("data_parallel_hybrid_lb is not applicable in headless mode")

    parallel_config = vllm_config.parallel_config
    local_engine_count = parallel_config.data_parallel_size_local

    if local_engine_count <= 0:
        raise ValueError("data_parallel_size_local must be > 0 in headless mode")

    shutdown_requested = False

    # Catch SIGTERM and SIGINT to allow graceful shutdown. The first signal
    # raises SystemExit so the blocking call below unwinds (running the
    # `finally` that closes the engine manager). Any further signal during
    # shutdown is only logged.
    def signal_handler(signum, frame):
        nonlocal shutdown_requested
        logger.debug("Received %d signal.", signum)
        if not shutdown_requested:
            shutdown_requested = True
            raise SystemExit

    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    if parallel_config.node_rank_within_dp > 0:
        from vllm.version import __version__ as VLLM_VERSION

        # Run headless workers (for multi-node PP/TP).
        host = parallel_config.master_addr
        head_node_address = f"{host}:{parallel_config.master_port}"
        logger.info(
            "Launching vLLM (v%s) headless multiproc executor, "
            "with head node address %s for torch.distributed process group.",
            VLLM_VERSION,
            head_node_address,
        )

        executor = MultiprocExecutor(vllm_config, monitor_workers=False)
        # NOTE(review): inline=True presumably runs the monitor in this
        # thread and blocks until the workers exit; confirm.
        executor.start_worker_monitor(inline=True)
        return

    host = parallel_config.data_parallel_master_ip
    port = parallel_config.data_parallel_rpc_port
    handshake_address = get_tcp_uri(host, port)

    logger.info(
        "Launching %d data parallel engine(s) in headless mode, "
        "with head node address %s.",
        local_engine_count,
        handshake_address,
    )

    # Create the engines. They are remote (non-local) clients of the head
    # node, so global engine indices start at this node's DP rank.
    engine_manager = CoreEngineProcManager(
        target_fn=EngineCoreProc.run_engine_core,
        local_engine_count=local_engine_count,
        start_index=parallel_config.data_parallel_rank,
        local_start_index=0,
        vllm_config=vllm_config,
        local_client=False,
        handshake_address=handshake_address,
        executor_class=Executor.get_class(vllm_config),
        log_stats=not engine_args.disable_log_stats,
    )

    try:
        engine_manager.join_first()
    finally:
        logger.info("Shutting down.")
        engine_manager.close()
216
217
218
219


def run_multi_api_server(args: argparse.Namespace):
    """Run the engine core(s) together with ``args.api_server_count`` API servers.

    A single listening socket is created here with ``setup_server`` and
    passed to ``APIServerProcessManager``, which starts ``num_api_servers``
    processes running ``run_api_server_worker_proc``. Engine cores (and the
    DP coordinator, when one is returned) are started via the
    ``launch_core_engines`` context manager. The function then blocks in
    ``wait_for_completion_or_failure``.

    Raises:
        ValueError: if ``VLLM_ALLOW_RUNTIME_LORA_UPDATING`` is enabled while
            more than one API server is requested.
    """
    assert not args.headless
    num_api_servers: int = args.api_server_count
    assert num_api_servers > 0

    # This is an environment-only incompatibility, so reject it before
    # setting up multiprocess Prometheus, binding the listening socket or
    # building the engine config. None of those are needed to decide it.
    if num_api_servers > 1 and envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
        raise ValueError(
            "VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used with api_server_count > 1"
        )

    if num_api_servers > 1:
        setup_multiprocess_prometheus()

    listen_address, sock = setup_server(args)

    engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
    engine_args._api_process_count = num_api_servers
    # NOTE(review): -1 appears to mark this launcher as not being one of the
    # API server ranks; confirm against AsyncEngineArgs.
    engine_args._api_process_rank = -1

    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    executor_class = Executor.get_class(vllm_config)
    log_stats = not engine_args.disable_log_stats

    parallel_config = vllm_config.parallel_config
    dp_rank = parallel_config.data_parallel_rank
    assert parallel_config.local_engines_only or dp_rank == 0

    api_server_manager: APIServerProcessManager | None = None

    from vllm.v1.engine.utils import get_engine_zmq_addresses

    addresses = get_engine_zmq_addresses(vllm_config, num_api_servers)

    # `addresses` is rebound to the value yielded by the launcher; the
    # post-context branch below reads frontend_stats_publish_address from it.
    with launch_core_engines(
        vllm_config, executor_class, log_stats, addresses, num_api_servers
    ) as (local_engine_manager, coordinator, addresses):
        # Construct common args for the APIServerProcessManager up-front.
        api_server_manager_kwargs = dict(
            target_server_fn=run_api_server_worker_proc,
            listen_address=listen_address,
            sock=sock,
            args=args,
            num_servers=num_api_servers,
            input_addresses=addresses.inputs,
            output_addresses=addresses.outputs,
            stats_update_address=coordinator.get_stats_publish_address()
            if coordinator
            else None,
        )

        # For dp ranks > 0 in external/hybrid DP LB modes, we must delay the
        # start of the API servers until the local engine is started
        # (after the launcher context manager exits),
        # since we get the front-end stats update address from the coordinator
        # via the handshake with the local engine.
        if dp_rank == 0 or not parallel_config.local_engines_only:
            # Start API servers using the manager.
            api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)

    # Start API servers now if they weren't already started.
    if api_server_manager is None:
        api_server_manager_kwargs["stats_update_address"] = (
            addresses.frontend_stats_publish_address
        )
        api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)

    # Wait for API servers
    wait_for_completion_or_failure(
        api_server_manager=api_server_manager,
        engine_manager=local_engine_manager,
        coordinator=coordinator,
    )
292
293


294
295
296
def run_api_server_worker_proc(
    listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
    """Entry point executed inside each spawned API server worker process.

    Labels the process with its worker index (``client_config["client_index"]``,
    default 0), both in the process title and as a prefix on stdout/stderr.
    It then runs ``run_server_worker`` on a uvloop event loop until the
    coroutine completes.
    """
    if not client_config:
        client_config = {}

    # Tag the process title and log output with this worker's index.
    set_process_title("APIServer", str(client_config.get("client_index", 0)))
    decorate_logs()

    worker = run_server_worker(
        listen_address, sock, args, client_config, **uvicorn_kwargs
    )
    uvloop.run(worker)