serve.py 8.81 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

import argparse
5
import signal
6
7
8

import uvloop

9
import vllm
10
import vllm.envs as envs
11
from vllm.entrypoints.cli.types import CLISubcommand
12
13
14
15
16
17
from vllm.entrypoints.openai.api_server import (
    run_server,
    run_server_worker,
    setup_server,
)
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
18
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
19
20
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
Cyrus Leung's avatar
Cyrus Leung committed
21
from vllm.utils.argparse_utils import FlexibleArgumentParser
22
from vllm.utils.network_utils import get_tcp_uri
23
from vllm.utils.system_utils import decorate_logs, set_process_title
24
from vllm.v1.engine.core import EngineCoreProc
25
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
26
from vllm.v1.executor import Executor
27
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
28
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
29
from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure
30
31

# Module-level logger for the ``vllm serve`` entrypoint helpers below.
logger = init_logger(__name__)
32

33
34
35
36
37
38
39
40
# Long-form help text for ``vllm serve``; passed as the subparser's
# ``description`` so it appears in ``vllm serve --help``.
DESCRIPTION = (
    "Launch a local OpenAI-compatible API server to serve LLM\n"
    "completions via HTTP. Defaults to Qwen/Qwen3-0.6B if no model is specified.\n"
    "\n"
    "Search by using: `--help=<ConfigGroup>` to explore options by section (e.g.,\n"
    "--help=ModelConfig, --help=Frontend)\n"
    "  Use `--help=all` to show all available flags at once.\n"
)

41
42

class ServeSubcommand(CLISubcommand):
    """The `serve` subcommand for the vLLM CLI.

    Registers the ``vllm serve`` argument parser and, when invoked, dispatches
    to one of three launch modes: headless (engine cores only), multiple API
    server processes, or a single API server running in this process.
    """

    name = "serve"

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Launch vLLM serving in the mode selected by the parsed ``args``."""
        # A positional model tag given on the CLI overrides ``args.model``.
        model_tag = getattr(args, "model_tag", None)
        if model_tag is not None:
            args.model = model_tag

        if args.headless or args.api_server_count < 1:
            # No API server in this process.
            run_headless(args)
        elif args.api_server_count > 1:
            run_multi_api_server(args)
        else:
            # Single API server (this process).
            uvloop.run(run_server(args))

    def validate(self, args: argparse.Namespace) -> None:
        """Validate the parsed ``serve`` arguments (see ``validate_parsed_serve_args``)."""
        validate_parsed_serve_args(args)

    def subparser_init(
        self, subparsers: argparse._SubParsersAction
    ) -> FlexibleArgumentParser:
        """Add the ``serve`` subparser to ``subparsers`` and return it populated.

        The parser is filled in by ``make_arg_parser`` and given the shared
        subcommand epilog.
        """
        subcmd_name = self.name
        parser = subparsers.add_parser(
            subcmd_name,
            help=(
                "Launch a local OpenAI-compatible API server to serve LLM "
                "completions via HTTP."
            ),
            description=DESCRIPTION,
            usage="vllm serve [model_tag] [options]",
        )

        # Continue with the parser object returned by make_arg_parser.
        parser = make_arg_parser(parser)
        parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=subcmd_name)
        return parser
79
80


81
def cmd_init() -> list[CLISubcommand]:
    """Return the CLI subcommands contributed by this module (only ``serve``)."""
    subcommands: list[CLISubcommand] = [ServeSubcommand()]
    return subcommands
83
84
85


def run_headless(args: argparse.Namespace):
    """Run vLLM engine processes without an API server in this process.

    After building the engine config (with ``headless=True``), one of two
    paths is taken based on ``parallel_config.node_rank_within_dp``:

    * ``> 0``: start a ``MultiprocExecutor`` for this node's workers (for
      multi-node PP/TP) and run its worker monitor inline.
    * otherwise: launch ``data_parallel_size_local`` engine core processes
      that handshake with the data-parallel master address, wait on
      ``join_first()``, then close the manager.

    Raises:
        ValueError: if ``args.api_server_count > 1``, if
            ``data_parallel_hybrid_lb`` is enabled, or if
            ``data_parallel_size_local`` is not positive.
    """
    if args.api_server_count > 1:
        raise ValueError("api_server_count can't be set in headless mode")

    # Create the EngineConfig.
    engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
    vllm_config = engine_args.create_engine_config(
        usage_context=UsageContext.OPENAI_API_SERVER, headless=True
    )

    if engine_args.data_parallel_hybrid_lb:
        raise ValueError("data_parallel_hybrid_lb is not applicable in headless mode")

    parallel_config = vllm_config.parallel_config
    local_engine_count = parallel_config.data_parallel_size_local
    if local_engine_count <= 0:
        raise ValueError("data_parallel_size_local must be > 0 in headless mode")

    _install_headless_signal_handlers()

    if parallel_config.node_rank_within_dp > 0:
        # Run headless workers (for multi-node PP/TP).
        from vllm.version import __version__ as VLLM_VERSION

        logger.info(
            "Launching vLLM (v%s) headless multiproc executor, "
            "with head node address %s for torch.distributed process group.",
            VLLM_VERSION,
            f"{parallel_config.master_addr}:{parallel_config.master_port}",
        )
        worker_executor = MultiprocExecutor(vllm_config, monitor_workers=False)
        worker_executor.start_worker_monitor(inline=True)
        return

    handshake_address = get_tcp_uri(
        parallel_config.data_parallel_master_ip,
        parallel_config.data_parallel_rpc_port,
    )
    logger.info(
        "Launching %d data parallel engine(s) in headless mode, "
        "with head node address %s.",
        local_engine_count,
        handshake_address,
    )

    # Create the engines.
    engine_manager = CoreEngineProcManager(
        target_fn=EngineCoreProc.run_engine_core,
        local_engine_count=local_engine_count,
        start_index=parallel_config.data_parallel_rank,
        local_start_index=0,
        vllm_config=vllm_config,
        local_client=False,
        handshake_address=handshake_address,
        executor_class=Executor.get_class(vllm_config),
        log_stats=not engine_args.disable_log_stats,
    )

    try:
        engine_manager.join_first()
    finally:
        logger.info("Shutting down.")
        engine_manager.close()


def _install_headless_signal_handlers() -> None:
    """Install SIGTERM/SIGINT handlers for graceful headless shutdown.

    The first signal received raises ``SystemExit``; any later signal is only
    logged, so the shutdown path already in progress is not interrupted again.
    """
    exit_in_progress = False

    def _on_signal(signum, frame):
        nonlocal exit_in_progress
        logger.debug("Received %d signal.", signum)
        if exit_in_progress:
            return
        exit_in_progress = True
        raise SystemExit

    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, _on_signal)
164
165
166
167


def run_multi_api_server(args: argparse.Namespace):
    """Launch the local engine core(s) and ``args.api_server_count`` API servers.

    This launcher process does the following, in order:

    1. Sets up the server socket via ``setup_server``.
    2. Builds the engine config.
    3. Starts the local engine core(s), plus a coordinator if one is
       returned, through ``launch_core_engines``.
    4. Spawns the API server worker processes via
       ``APIServerProcessManager``. The workers receive the same
       ``listen_address``/``sock``.
    5. Waits on everything with ``wait_for_completion_or_failure``.

    Raises:
        ValueError: if ``VLLM_ALLOW_RUNTIME_LORA_UPDATING`` is enabled while
            more than one API server is requested.
    """
    assert not args.headless
    num_api_servers: int = args.api_server_count
    assert num_api_servers > 0

    # This restriction depends only on the env flag and the server count, so
    # check it before any side effects. These include Prometheus multiprocess
    # setup, the server socket and engine config creation. Failing here
    # avoids doing that work (and holding the socket) for an invalid setup.
    if num_api_servers > 1 and envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
        raise ValueError(
            "VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used with api_server_count > 1"
        )

    if num_api_servers > 1:
        setup_multiprocess_prometheus()

    listen_address, sock = setup_server(args)

    engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
    # Inform engine config creation of the API server process count.
    # NOTE(review): rank -1 appears to mark this launcher as not being one of
    # the API server processes - confirm against AsyncEngineArgs.
    engine_args._api_process_count = num_api_servers
    engine_args._api_process_rank = -1

    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    executor_class = Executor.get_class(vllm_config)
    log_stats = not engine_args.disable_log_stats

    parallel_config = vllm_config.parallel_config
    dp_rank = parallel_config.data_parallel_rank
    assert parallel_config.local_engines_only or dp_rank == 0

    api_server_manager: APIServerProcessManager | None = None

    with launch_core_engines(
        vllm_config, executor_class, log_stats, num_api_servers
    ) as (local_engine_manager, coordinator, addresses):
        # Construct common args for the APIServerProcessManager up-front.
        api_server_manager_kwargs = dict(
            target_server_fn=run_api_server_worker_proc,
            listen_address=listen_address,
            sock=sock,
            args=args,
            num_servers=num_api_servers,
            input_addresses=addresses.inputs,
            output_addresses=addresses.outputs,
            stats_update_address=coordinator.get_stats_publish_address()
            if coordinator
            else None,
        )

        # For dp ranks > 0 in external/hybrid DP LB modes, we must delay the
        # start of the API servers until the local engine is started
        # (after the launcher context manager exits),
        # since we get the front-end stats update address from the coordinator
        # via the handshake with the local engine.
        if dp_rank == 0 or not parallel_config.local_engines_only:
            # Start API servers using the manager.
            api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)

    # Start API servers now if they weren't already started.
    if api_server_manager is None:
        api_server_manager_kwargs["stats_update_address"] = (
            addresses.frontend_stats_publish_address
        )
        api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)

    # Wait for API servers
    wait_for_completion_or_failure(
        api_server_manager=api_server_manager,
        engine_manager=local_engine_manager,
        coordinator=coordinator,
    )
236
237


238
239
240
def run_api_server_worker_proc(
    listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
    """Entrypoint for individual API server worker processes.

    The worker index is ``client_config["client_index"]``, defaulting to 0.
    It is used for the process title and the stdout/stderr prefix. The worker
    then runs ``run_server_worker`` on a uvloop event loop. A missing or empty
    ``client_config`` is normalized to ``{}`` before being forwarded.
    """
    if not client_config:
        client_config = {}

    # Set process title and add process-specific prefix to stdout and stderr.
    worker_index = client_config.get("client_index", 0)
    set_process_title("APIServer", str(worker_index))
    decorate_logs()

    server_coro = run_server_worker(
        listen_address, sock, args, client_config, **uvicorn_kwargs
    )
    uvloop.run(server_coro)