serve.py 10.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

import argparse
5
import os
6
import signal
7
import sys
8
from typing import Optional
9
10
11

import uvloop

12
import vllm
13
import vllm.envs as envs
14
from vllm.entrypoints.cli.types import CLISubcommand
15
16
from vllm.entrypoints.openai.api_server import (run_server, run_server_worker,
                                                setup_server)
17
18
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                              validate_parsed_serve_args)
19
from vllm.entrypoints.utils import (VLLM_SUBCMD_PARSER_EPILOG,
20
                                    show_filtered_argument_or_group_from_help)
21
from vllm.executor.multiproc_worker_utils import _add_prefix
22
23
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
24
from vllm.utils import FlexibleArgumentParser, get_tcp_uri
25
from vllm.v1.engine.core import EngineCoreProc
26
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
27
from vllm.v1.executor.abstract import Executor
28
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
29
30
from vllm.v1.utils import (APIServerProcessManager,
                           wait_for_completion_or_failure)
31
32

logger = init_logger(__name__)
33
34
35
36


class ServeSubcommand(CLISubcommand):
    """CLI subcommand implementing ``vllm serve``.

    Registers the ``serve`` argument parser and dispatches the parsed
    arguments to one of three launch paths: a headless engine launch, a
    single API server running in this process, or several API server
    processes.
    """
    name = "serve"

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Run ``vllm serve`` with already-parsed CLI arguments.

        Args:
            args: Namespace produced by the parser built in
                ``subparser_init``.

        Raises:
            ValueError: If ``--data-parallel-start-rank`` is given outside
                headless mode.
        """
        # A model given positionally on the command line takes precedence
        # over any ``model`` value that was set another way.
        model_tag = getattr(args, 'model_tag', None)
        if model_tag is not None:
            args.model = model_tag

        # Headless mode is requested explicitly, or implied when fewer than
        # one API server is asked for.
        if args.headless or args.api_server_count < 1:
            run_headless(args)
            return

        if args.data_parallel_start_rank:
            raise ValueError("data_parallel_start_rank is only "
                             "applicable in headless mode")

        if args.api_server_count > 1:
            run_multi_api_server(args)
            return

        # Single API server (this process).
        uvloop.run(run_server(args))

    def validate(self, args: argparse.Namespace) -> None:
        """Validate the parsed ``serve`` arguments."""
        validate_parsed_serve_args(args)

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        """Create, populate and return the ``serve`` subparser.

        Adds the serve-specific options, then passes the parser through
        ``make_arg_parser``. It also applies
        ``show_filtered_argument_or_group_from_help`` and sets the shared
        subcommand epilog.

        Args:
            subparsers: The parent parser's subparsers action.

        Returns:
            The configured ``serve`` parser.
        """
        parser = subparsers.add_parser(
            "serve",
            usage="vllm serve [model_tag] [options]",
            help="Start the vLLM OpenAI Compatible API server.",
            description="Start the vLLM OpenAI Compatible API server.")

        # Optional positional model; may instead come from the config file.
        parser.add_argument("model_tag",
                            nargs='?',
                            type=str,
                            help="The model tag to serve "
                            "(optional if specified in config)")
        parser.add_argument("--headless",
                            default=False,
                            action='store_true',
                            help="Run in headless mode. See multi-node data "
                            "parallel documentation for more details.")
        parser.add_argument('--data-parallel-start-rank',
                            '-dpr',
                            default=0,
                            type=int,
                            help='Starting data parallel rank for secondary '
                            'nodes.')
        parser.add_argument('--api-server-count',
                            '-asc',
                            default=1,
                            type=int,
                            help='How many API server processes to run.')
        parser.add_argument(
            "--config",
            required=False,
            default='',
            type=str,
            help="Read CLI options from a config file. "
            "Must be a YAML with the following options: "
            "https://docs.vllm.ai/en/latest/configuration/serve_args.html")

        parser = make_arg_parser(parser)
        show_filtered_argument_or_group_from_help(parser, "serve")
        parser.epilog = VLLM_SUBCMD_PARSER_EPILOG
        return parser
103
104


105
def cmd_init() -> list[CLISubcommand]:
    """Return the CLI subcommands defined by this module (only ``serve``)."""
    serve_subcommand = ServeSubcommand()
    return [serve_subcommand]
107
108
109
110


def run_headless(args: argparse.Namespace):
    """Launch this node's data-parallel engine core processes with no API server.

    Installs SIGTERM/SIGINT handlers that raise ``SystemExit`` and starts the
    local engines through ``CoreEngineProcManager``. It then blocks in
    ``join_first()`` (presumably until the first engine process exits) and
    always calls ``close()`` on the manager afterwards.

    Args:
        args: Parsed ``vllm serve`` arguments.

    Raises:
        ValueError: If ``api_server_count > 1``, the V1 engine is not in use,
            ``data_parallel_size_local`` is not positive, or
            ``data_parallel_rank`` is set.
    """
    if args.api_server_count > 1:
        raise ValueError("api_server_count can't be set in headless mode")

    # Build the engine configuration from the CLI arguments.
    async_engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
    vllm_config = async_engine_args.create_engine_config(
        usage_context=UsageContext.OPENAI_API_SERVER)

    # NOTE(review): this check stays after config creation; presumably
    # create_engine_config can influence the V1 decision. Confirm before
    # moving it earlier.
    if not envs.VLLM_USE_V1:
        raise ValueError("Headless mode is only supported for V1")

    dp_config = vllm_config.parallel_config
    num_local_engines = dp_config.data_parallel_size_local
    if num_local_engines <= 0:
        raise ValueError("data_parallel_size_local must be > 0 in "
                         "headless mode")
    if dp_config.data_parallel_rank is not None:
        raise ValueError("data_parallel_rank is not applicable in "
                         "headless mode")

    # Head-node address the local engines hand-shake with. The RPC port is
    # taken from the engine args (TODO: also expose it on the config).
    handshake_address = get_tcp_uri(dp_config.data_parallel_master_ip,
                                    async_engine_args.data_parallel_rpc_port)

    def _exit_on_signal(signum, frame):
        # Convert the signal into SystemExit so shutdown goes through normal
        # unwinding (including the ``finally`` block below).
        logger.debug("Received %d signal.", signum)
        raise SystemExit

    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, _exit_on_signal)

    logger.info(
        "Launching %d data parallel engine(s) in headless mode, "
        "with head node address %s.", num_local_engines, handshake_address)

    engine_manager = CoreEngineProcManager(
        target_fn=EngineCoreProc.run_engine_core,
        local_engine_count=num_local_engines,
        start_index=args.data_parallel_start_rank,
        local_start_index=0,
        vllm_config=vllm_config,
        local_client=False,
        handshake_address=handshake_address,
        executor_class=Executor.get_class(vllm_config),
        log_stats=not async_engine_args.disable_log_stats,
    )

    try:
        engine_manager.join_first()
    finally:
        logger.info("Shutting down.")
        engine_manager.close()
167
168
169
170
171
172
173
174
175
176
177
178
179


def run_multi_api_server(args: argparse.Namespace):
    """Launch the engine core processes and ``args.api_server_count`` API
    server processes, then wait until they complete or one of them fails.

    Args:
        args: Parsed ``vllm serve`` arguments. Must not be in headless mode
            and must request at least one API server.

    Raises:
        ValueError: If ``api_server_count > 1`` is combined with a non-V1
            engine or with ``VLLM_ALLOW_RUNTIME_LORA_UPDATING``.
    """

    assert not args.headless
    num_api_servers = args.api_server_count
    assert num_api_servers > 0

    # Several API server processes share metrics through multiprocess
    # Prometheus. Presumably this must run before metrics are created, so it
    # is done first.
    if num_api_servers > 1:
        setup_multiprocess_prometheus()

    # Bind the listening socket here; it is passed to the
    # APIServerProcessManager below.
    listen_address, sock = setup_server(args)

    engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)
    model_config = vllm_config.model_config

    if num_api_servers > 1:
        if not envs.VLLM_USE_V1:
            raise ValueError("api_server_count > 1 is only supported for V1")

        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
            raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
                             "with api_server_count > 1")

        # The multi-modal preprocessor cache is force-disabled when more than
        # one API server process is used.
        if model_config.is_multimodal_model and not (
                model_config.disable_mm_preprocessor_cache):
            logger.warning(
                "Multi-modal preprocessor cache will be disabled for"
                " api_server_count > 1")
            model_config.disable_mm_preprocessor_cache = True

    executor_class = Executor.get_class(vllm_config)
    log_stats = not engine_args.disable_log_stats

    parallel_config = vllm_config.parallel_config
    dp_rank = parallel_config.data_parallel_rank
    external_dp_lb = parallel_config.data_parallel_external_lb
    # Without external DP load balancing, this launcher must be DP rank 0.
    assert external_dp_lb or dp_rank == 0

    api_server_manager: Optional[APIServerProcessManager] = None

    with launch_core_engines(vllm_config, executor_class, log_stats,
                             num_api_servers) as (local_engine_manager,
                                                  coordinator, addresses):

        # Construct common args for the APIServerProcessManager up-front.
        api_server_manager_kwargs = dict(
            target_server_fn=run_api_server_worker_proc,
            listen_address=listen_address,
            sock=sock,
            args=args,
            num_servers=num_api_servers,
            input_addresses=addresses.inputs,
            output_addresses=addresses.outputs,
            stats_update_address=coordinator.get_stats_publish_address()
            if coordinator else None)

        # For dp ranks > 0 in external DP LB mode, we must delay the
        # start of the API servers until the local engine is started
        # (after the launcher context manager exits),
        # since we get the front-end stats update address from the coordinator
        # via the handshake with the local engine.
        if dp_rank == 0 or not external_dp_lb:
            # Start API servers using the manager.
            api_server_manager = APIServerProcessManager(
                **api_server_manager_kwargs)

    # Start API servers now if they weren't already started.
    if api_server_manager is None:
        api_server_manager_kwargs["stats_update_address"] = (
            addresses.frontend_stats_publish_address)
        api_server_manager = APIServerProcessManager(
            **api_server_manager_kwargs)

    # Block on the API servers, local engines and coordinator.
    wait_for_completion_or_failure(api_server_manager=api_server_manager,
                                   engine_manager=local_engine_manager,
                                   coordinator=coordinator)
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265


def run_api_server_worker_proc(listen_address,
                               sock,
                               args,
                               client_config=None,
                               **uvicorn_kwargs) -> None:
    """Entrypoint for one API server worker process.

    Tags this process's stdout and stderr with its multiprocessing name and
    PID (via ``_add_prefix``). It then runs ``run_server_worker`` on a
    uvloop event loop with the given listen address, socket, args, client
    config and extra uvicorn keyword arguments.
    """
    from multiprocessing import current_process

    # Label output so it can be attributed to this worker process.
    worker_name, worker_pid = current_process().name, os.getpid()
    for stream in (sys.stdout, sys.stderr):
        _add_prefix(stream, worker_name, worker_pid)

    server_coro = run_server_worker(listen_address, sock, args,
                                    client_config, **uvicorn_kwargs)
    uvloop.run(server_coro)