serve.py 11.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

import argparse
5
import signal
6
import time
7
8
9

import uvloop

10
import vllm
11
import vllm.envs as envs
12
from vllm.entrypoints.cli.types import CLISubcommand
13
from vllm.entrypoints.openai.api_server import run_server, setup_server
14
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
15
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
16
17
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
Cyrus Leung's avatar
Cyrus Leung committed
18
from vllm.utils.argparse_utils import FlexibleArgumentParser
19
from vllm.utils.network_utils import get_tcp_uri
20
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
21
from vllm.v1.executor import Executor
22
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
23
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
24
from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure
25
26

# Module-level logger, created through vLLM's `init_logger` helper.
logger = init_logger(__name__)
27

28
29
30
31
32
33
34
35
# Long-form help text passed as `description=` when the `serve` subparser is
# created in `ServeSubcommand.subparser_init`.
DESCRIPTION = """Launch a local OpenAI-compatible API server to serve LLM
completions via HTTP. Defaults to Qwen/Qwen3-0.6B if no model is specified.

Search by using: `--help=<ConfigGroup>` to explore options by section (e.g.,
--help=ModelConfig, --help=Frontend)
  Use `--help=all` to show all available flags at once.
"""

36
37

class ServeSubcommand(CLISubcommand):
    """The `serve` subcommand for the vLLM CLI."""

    name = "serve"

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Resolve the launch mode from ``args`` and start the server(s).

        The parsed namespace is mutated in place: ``args.model`` may be
        overridden by the positional ``model_tag``, and
        ``args.api_server_count`` is defaulted/capped before dispatching to
        one of ``run_headless``, ``run_multi_api_server`` or ``run_server``.

        Args:
            args: Parsed command-line arguments for ``vllm serve``.

        Raises:
            ValueError: If ``--headless`` is combined with a positive
                ``--api-server-count``, or if both external and hybrid data
                parallel load-balancing modes are requested.
        """
        # If model is specified in CLI (as positional arg), it takes precedence
        if hasattr(args, "model_tag") and args.model_tag is not None:
            args.model = args.model_tag

        if getattr(args, "grpc", False):
            # Imported lazily so the gRPC entrypoint is only loaded on demand.
            from vllm.entrypoints.grpc_server import serve_grpc

            uvloop.run(serve_grpc(args))
            return

        if args.headless:
            if args.api_server_count is not None and args.api_server_count > 0:
                raise ValueError(
                    f"--api-server-count={args.api_server_count} cannot be "
                    "used with --headless (no API servers are started in "
                    "headless mode)."
                )
            # Default to 0 in headless mode (no API servers)
            args.api_server_count = 0

        # Detect LB mode for defaulting api_server_count.
        # External LB: --data-parallel-external-lb or --data-parallel-rank
        # Hybrid LB: --data-parallel-hybrid-lb or --data-parallel-start-rank
        is_external_lb = (
            args.data_parallel_external_lb or args.data_parallel_rank is not None
        )
        is_hybrid_lb = (
            args.data_parallel_hybrid_lb or args.data_parallel_start_rank is not None
        )

        if is_external_lb and is_hybrid_lb:
            raise ValueError(
                "Cannot use both external and hybrid data parallel load "
                "balancing modes. External LB is enabled via "
                "--data-parallel-external-lb or --data-parallel-rank. "
                "Hybrid LB is enabled via --data-parallel-hybrid-lb or "
                "--data-parallel-start-rank. Use one mode or the other."
            )

        # Default api_server_count if not explicitly set.
        # - External LB: Leave as 1 (external LB handles distribution)
        # - Hybrid LB: Use local DP size (internal LB for local ranks only)
        # - Internal LB: Use full DP size
        if args.api_server_count is None:
            if is_external_lb:
                args.api_server_count = 1
            elif is_hybrid_lb:
                args.api_server_count = args.data_parallel_size_local or 1
                if args.api_server_count > 1:
                    logger.info(
                        "Defaulting api_server_count to data_parallel_size_local "
                        "(%d) for hybrid LB mode.",
                        args.api_server_count,
                    )
            else:
                args.api_server_count = args.data_parallel_size
                if args.api_server_count > 1:
                    logger.info(
                        "Defaulting api_server_count to data_parallel_size (%d).",
                        args.api_server_count,
                    )

        # Elastic EP currently only supports running with at most one API server.
        if getattr(args, "enable_elastic_ep", False) and args.api_server_count > 1:
            logger.warning(
                "Elastic EP only supports running with at most one API server. "
                "Capping api_server_count from %d to 1.",
                args.api_server_count,
            )
            args.api_server_count = 1

        if args.api_server_count < 1:
            run_headless(args)
        elif args.api_server_count > 1:
            run_multi_api_server(args)
        else:
            # Single API server (this process).
            # NOTE(review): the count is reset to None before run_server is
            # called; presumably run_server expects "unset" in the
            # single-process case -- confirm against run_server.
            args.api_server_count = None
            uvloop.run(run_server(args))

    def validate(self, args: argparse.Namespace) -> None:
        """Validate parsed `serve` arguments via `validate_parsed_serve_args`."""
        validate_parsed_serve_args(args)

    def subparser_init(
        self, subparsers: argparse._SubParsersAction
    ) -> FlexibleArgumentParser:
        """Register the `serve` subparser and return it.

        Args:
            subparsers: The sub-parsers action of the parent CLI parser.

        Returns:
            The parser produced by `make_arg_parser`, with its epilog set
            from `VLLM_SUBCMD_PARSER_EPILOG`.
        """
        serve_parser = subparsers.add_parser(
            self.name,
            help="Launch a local OpenAI-compatible API server to serve LLM "
            "completions via HTTP.",
            description=DESCRIPTION,
            usage="vllm serve [model_tag] [options]",
        )

        serve_parser = make_arg_parser(serve_parser)
        serve_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=self.name)
        return serve_parser
141
142


143
def cmd_init() -> list[CLISubcommand]:
    """Return the CLI subcommands provided by this module."""
    serve_subcommand = ServeSubcommand()
    return [serve_subcommand]
145
146
147


def run_headless(args: argparse.Namespace):
    """Launch engine processes only; no API server is started here.

    Depending on ``node_rank_within_dp`` this either runs a
    ``MultiprocExecutor`` worker monitor inline, or starts a
    ``CoreEngineProcManager`` and blocks monitoring engine liveness until
    exit, shutting the engines down afterwards.

    Args:
        args: Parsed command-line arguments for ``vllm serve``.

    Raises:
        ValueError: If ``api_server_count`` is greater than 1, if hybrid
            load balancing is enabled, or if ``data_parallel_size_local``
            is not positive.
    """
    if args.api_server_count > 1:
        raise ValueError("api_server_count can't be set in headless mode")

    # Build the engine configuration from the CLI arguments.
    engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
    vllm_config = engine_args.create_engine_config(
        usage_context=UsageContext.OPENAI_API_SERVER, headless=True
    )

    if engine_args.data_parallel_hybrid_lb:
        raise ValueError("data_parallel_hybrid_lb is not applicable in headless mode")

    dp_config = vllm_config.parallel_config
    num_local_engines = dp_config.data_parallel_size_local
    if num_local_engines <= 0:
        raise ValueError("data_parallel_size_local must be > 0 in headless mode")

    shutdown_requested = False

    def on_termination_signal(signum, frame):
        # Only the first SIGTERM/SIGINT triggers SystemExit; later ones are
        # just logged so an in-progress shutdown is not interrupted again.
        nonlocal shutdown_requested
        logger.debug("Received %d signal.", signum)
        if shutdown_requested:
            return
        shutdown_requested = True
        raise SystemExit

    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, on_termination_signal)

    if dp_config.node_rank_within_dp > 0:
        from vllm.version import __version__ as VLLM_VERSION

        # Run headless workers (for multi-node PP/TP).
        head_node_address = f"{dp_config.master_addr}:{dp_config.master_port}"
        logger.info(
            "Launching vLLM (v%s) headless multiproc executor, "
            "with head node address %s for torch.distributed process group.",
            VLLM_VERSION,
            head_node_address,
        )

        worker_executor = MultiprocExecutor(vllm_config, monitor_workers=False)
        worker_executor.start_worker_monitor(inline=True)
        return

    handshake_address = get_tcp_uri(
        dp_config.data_parallel_master_ip, dp_config.data_parallel_rpc_port
    )

    logger.info(
        "Launching %d data parallel engine(s) in headless mode, "
        "with head node address %s.",
        num_local_engines,
        handshake_address,
    )

    # Spawn the local engine processes.
    manager_kwargs = dict(
        local_engine_count=num_local_engines,
        start_index=vllm_config.parallel_config.data_parallel_rank,
        local_start_index=0,
        vllm_config=vllm_config,
        local_client=False,
        handshake_address=handshake_address,
        executor_class=Executor.get_class(vllm_config),
        log_stats=not engine_args.disable_log_stats,
    )
    engine_manager = CoreEngineProcManager(**manager_kwargs)

    try:
        engine_manager.monitor_engine_liveness()
    finally:
        # A bounded wait is only used when shutdown was requested by signal.
        if shutdown_requested:
            grace_period = vllm_config.shutdown_timeout
            logger.info("Waiting up to %d seconds for processes to exit", grace_period)
        else:
            grace_period = None
        engine_manager.shutdown(timeout=grace_period)
        logger.info("Shutting down.")
229
230
231
232


def run_multi_api_server(args: argparse.Namespace):
    """Launch core engines plus ``args.api_server_count`` API server processes.

    Blocks in ``wait_for_completion_or_failure`` and then shuts down the API
    server manager, the local engine manager and the coordinator (the latter
    two only when present).

    Args:
        args: Parsed command-line arguments for ``vllm serve``; must not be
            in headless mode.

    Raises:
        ValueError: If more than one API server is requested while
            ``VLLM_ALLOW_RUNTIME_LORA_UPDATING`` is enabled.
    """
    assert not args.headless
    num_api_servers: int = args.api_server_count
    assert num_api_servers > 0

    multiple_servers = num_api_servers > 1
    if multiple_servers:
        setup_multiprocess_prometheus()

    shutdown_requested = False

    def on_termination_signal(signum, frame):
        # Only the first SIGTERM/SIGINT triggers SystemExit; later ones are
        # just logged so an in-progress shutdown is not interrupted again.
        nonlocal shutdown_requested
        logger.debug("Received %d signal.", signum)
        if shutdown_requested:
            return
        shutdown_requested = True
        raise SystemExit

    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, on_termination_signal)

    listen_address, sock = setup_server(args)

    engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
    engine_args._api_process_count = num_api_servers
    engine_args._api_process_rank = -1

    vllm_config = engine_args.create_engine_config(
        usage_context=UsageContext.OPENAI_API_SERVER
    )

    if multiple_servers and envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
        raise ValueError(
            "VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used with api_server_count > 1"
        )

    executor_class = Executor.get_class(vllm_config)
    log_stats = not engine_args.disable_log_stats

    dp_config = vllm_config.parallel_config
    dp_rank = dp_config.data_parallel_rank
    assert dp_config.local_engines_only or dp_rank == 0

    api_server_manager: APIServerProcessManager | None = None

    from vllm.v1.engine.utils import get_engine_zmq_addresses

    addresses = get_engine_zmq_addresses(vllm_config, num_api_servers)

    with launch_core_engines(
        vllm_config, executor_class, log_stats, addresses, num_api_servers
    ) as (local_engine_manager, coordinator, addresses, tensor_queue):
        # Gather the arguments for the APIServerProcessManager up-front.
        stats_update_address = (
            coordinator.get_stats_publish_address() if coordinator else None
        )

        # Start API servers.
        api_server_manager = APIServerProcessManager(
            listen_address=listen_address,
            sock=sock,
            args=args,
            num_servers=num_api_servers,
            input_addresses=addresses.inputs,
            output_addresses=addresses.outputs,
            stats_update_address=stats_update_address,
            tensor_queue=tensor_queue,
        )

    # Wait for API servers.
    try:
        wait_for_completion_or_failure(
            api_server_manager=api_server_manager,
            engine_manager=local_engine_manager,
            coordinator=coordinator,
        )
    finally:
        timeout = shutdown_by = None
        if shutdown_requested:
            timeout = vllm_config.shutdown_timeout
            shutdown_by = time.monotonic() + timeout
            logger.info("Waiting up to %d seconds for processes to exit", timeout)

        def remaining(deadline: float | None) -> float | None:
            # Time left until the shared deadline, never negative; None
            # means "no deadline".
            if deadline is None:
                return None
            return max(deadline - time.monotonic(), 0.0)

        api_server_manager.shutdown(timeout=timeout)
        if local_engine_manager:
            local_engine_manager.shutdown(timeout=remaining(shutdown_by))
        if coordinator:
            coordinator.shutdown(timeout=remaining(shutdown_by))