serve.py 12 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

import argparse
5
import os
6
import signal
7
import sys
8
9

import uvloop
10
import zmq
11

12
13
import vllm.envs as envs
from vllm import AsyncEngineArgs
14
from vllm.entrypoints.cli.types import CLISubcommand
15
16
from vllm.entrypoints.openai.api_server import (run_server, run_server_worker,
                                                setup_server)
17
18
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                              validate_parsed_serve_args)
19
20
from vllm.entrypoints.utils import (VLLM_SERVE_PARSER_EPILOG,
                                    show_filtered_argument_or_group_from_help)
21
from vllm.executor.multiproc_worker_utils import _add_prefix
22
23
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
24
25
from vllm.utils import FlexibleArgumentParser, get_tcp_uri, zmq_socket_ctx
from vllm.v1.engine.coordinator import DPCoordinator
26
27
28
from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.engine.core_client import CoreEngineProcManager
from vllm.v1.executor.abstract import Executor
29
30
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
from vllm.v1.utils import (APIServerProcessManager, CoreEngine,
Rui Qiao's avatar
Rui Qiao committed
31
32
                           CoreEngineActorManager, EngineZmqAddresses,
                           get_engine_client_zmq_addr,
33
34
                           wait_for_completion_or_failure,
                           wait_for_engine_startup)
35
36

# Module-level logger shared by the `vllm serve` entrypoints in this file.
logger = init_logger(__name__)
37
38
39
40
41
42
43
44
45
46
47


class ServeSubcommand(CLISubcommand):
    """The `serve` subcommand for the vLLM CLI.

    Depending on the parsed flags, starts engine processes only (headless),
    several API server processes, or a single API server in this process.
    """

    def __init__(self):
        self.name = "serve"
        super().__init__()

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Run `vllm serve` with the parsed command-line ``args``.

        Dispatch:
          * ``--headless`` or ``--api-server-count < 1`` -> ``run_headless``
          * ``--api-server-count > 1`` -> ``run_multi_api_server``
          * otherwise -> a single API server running in this process.
        """
        # If model is specified in CLI (as positional arg), it takes
        # precedence over any `model` value already present in `args`.
        model_tag = getattr(args, 'model_tag', None)
        if model_tag is not None:
            args.model = model_tag

        if args.headless or args.api_server_count < 1:
            run_headless(args)
        elif args.api_server_count > 1:
            run_multi_api_server(args)
        else:
            # Single API server (this process).
            uvloop.run(run_server(args))

    def validate(self, args: argparse.Namespace) -> None:
        """Validate the parsed `serve` arguments."""
        validate_parsed_serve_args(args)

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        """Register the `serve` subparser and all of its options.

        Returns:
            The configured parser, extended by ``make_arg_parser`` with the
            OpenAI-server/engine options and given the serve epilog.
        """
        serve_parser = subparsers.add_parser(
            "serve",
            help="Start the vLLM OpenAI Compatible API server.",
            description="Start the vLLM OpenAI Compatible API server.",
            usage="vllm serve [model_tag] [options]")
        serve_parser.add_argument("model_tag",
                                  type=str,
                                  nargs='?',
                                  help="The model tag to serve "
                                  "(optional if specified in config)")
        serve_parser.add_argument(
            "--headless",
            action='store_true',
            default=False,
            help="Run in headless mode. See multi-node data parallel "
            "documentation for more details.")
        serve_parser.add_argument(
            '--data-parallel-start-rank',
            '-dpr',
            type=int,
            default=0,
            help='Starting data parallel rank for secondary nodes.')
        serve_parser.add_argument('--api-server-count',
                                  '-asc',
                                  type=int,
                                  default=1,
                                  help='How many API server processes to run.')
        # Adjacent string literals are concatenated verbatim, so each piece
        # must carry its own trailing space to keep the help text readable.
        serve_parser.add_argument(
            "--config",
            type=str,
            default='',
            required=False,
            help="Read CLI options from a config file. "
            "Must be a YAML with the following options: "
            "https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#cli-reference"
        )

        serve_parser = make_arg_parser(serve_parser)
        show_filtered_argument_or_group_from_help(serve_parser)
        serve_parser.epilog = VLLM_SERVE_PARSER_EPILOG
        return serve_parser
107
108


109
def cmd_init() -> list[CLISubcommand]:
    """Return the CLI subcommands this module contributes (just `serve`)."""
    subcommands: list[CLISubcommand] = [ServeSubcommand()]
    return subcommands
111
112
113
114


def run_headless(args: argparse.Namespace):
    """Start only this node's data-parallel engine processes (no API server).

    The engines handshake with the head node at
    ``data_parallel_master_ip:data_parallel_rpc_port`` taken from the
    resolved parallel config. This call blocks in
    ``engine_manager.join_first()`` and always closes the engine manager on
    exit, including when SIGTERM/SIGINT is turned into ``SystemExit``.

    Raises:
        ValueError: if ``--api-server-count`` > 1, if V1 is not enabled, or
            if ``data_parallel_size_local`` is not positive.
    """

    if args.api_server_count > 1:
        raise ValueError("api_server_count can't be set in headless mode")

    # Create the EngineConfig.
    engine_args = AsyncEngineArgs.from_cli_args(args)
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    # NOTE(review): kept after create_engine_config(), presumably because
    # building the config can resolve the V1 default -- confirm before
    # moving this check earlier.
    if not envs.VLLM_USE_V1:
        raise ValueError("Headless mode is only supported for V1")

    parallel_config = vllm_config.parallel_config
    local_engine_count = parallel_config.data_parallel_size_local
    host = parallel_config.data_parallel_master_ip
    # Read the port from the resolved parallel config, the same source the
    # head node (run_multi_api_server) binds its handshake socket from, so
    # both sides agree even if the raw CLI value was left unset.
    port = parallel_config.data_parallel_rpc_port
    handshake_address = get_tcp_uri(host, port)

    if local_engine_count <= 0:
        raise ValueError("data_parallel_size_local must be > 0 in "
                         "headless mode")

    # Catch SIGTERM and SIGINT to allow graceful shutdown: raising
    # SystemExit unwinds into the `finally` below, which closes the engines.
    def signal_handler(signum, frame):
        logger.debug("Received %d signal.", signum)
        raise SystemExit

    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    logger.info(
        "Launching %d data parallel engine(s) in headless mode, "
        "with head node address %s.", local_engine_count, handshake_address)

    # Create the engines. Global DP ranks start at --data-parallel-start-rank;
    # local indices start at 0 on this node.
    engine_manager = CoreEngineProcManager(
        target_fn=EngineCoreProc.run_engine_core,
        local_engine_count=local_engine_count,
        start_index=args.data_parallel_start_rank,
        local_start_index=0,
        vllm_config=vllm_config,
        on_head_node=False,
        handshake_address=handshake_address,
        executor_class=Executor.get_class(vllm_config),
        log_stats=not engine_args.disable_log_stats,
    )

    try:
        engine_manager.join_first()
    finally:
        logger.info("Shutting down.")
        engine_manager.close()
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233


def run_multi_api_server(args: argparse.Namespace):
    """Run one or more API server processes plus the engines they talk to.

    Steps:
      1. Bind the shared listening socket and build the engine config.
      2. For ``api_server_count > 1``, validate that the configuration
         supports multiple API servers.
      3. Allocate one input and one output ZMQ address per API server and,
         when ``data_parallel_size > 1``, start a DP coordinator process.
      4. Start the engines, either through the Ray actor manager or as
         local engine processes that handshake over a ROUTER socket.
      5. Start the API server worker processes and block in
         ``wait_for_completion_or_failure``.

    Raises:
        ValueError: if ``api_server_count > 1`` is used without V1 or
            together with ``VLLM_ALLOW_RUNTIME_LORA_UPDATING``.
    """

    assert not args.headless
    num_api_servers = args.api_server_count
    assert num_api_servers > 0

    if num_api_servers > 1:
        setup_multiprocess_prometheus()

    listen_address, sock = setup_server(args)

    engine_args = AsyncEngineArgs.from_cli_args(args)
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)
    model_config = vllm_config.model_config

    if num_api_servers > 1:
        if not envs.VLLM_USE_V1:
            raise ValueError("api_server_count > 1 is only supported for V1")

        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
            raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
                             "with api_server_count > 1")

        # The multi-modal preprocessor cache is force-disabled whenever more
        # than one API server process runs.
        if model_config.is_multimodal_model and not (
                model_config.disable_mm_preprocessor_cache):
            logger.warning(
                "Multi-modal preprocessor cache will be disabled for"
                " api_server_count > 1")
            model_config.disable_mm_preprocessor_cache = True

    parallel_config = vllm_config.parallel_config

    assert parallel_config.data_parallel_rank == 0

    dp_size = parallel_config.data_parallel_size
    local_engine_count = parallel_config.data_parallel_size_local
    host = parallel_config.data_parallel_master_ip
    # All engines run on this node when the local count covers the DP size.
    local_only = local_engine_count == dp_size

    # Set up input and output addresses (one of each per API server).
    input_addresses = [
        get_engine_client_zmq_addr(local_only, host)
        for _ in range(num_api_servers)
    ]
    output_addresses = [
        get_engine_client_zmq_addr(local_only, host)
        for _ in range(num_api_servers)
    ]

    addresses = EngineZmqAddresses(
        inputs=input_addresses,
        outputs=output_addresses,
    )

    # Set up coordinator for dp > 1.
    coordinator = None
    stats_update_address = None
    if dp_size > 1:
        coordinator = DPCoordinator(parallel_config)
        addresses.coordinator_input, addresses.coordinator_output = (
            coordinator.get_engine_socket_addresses())
        stats_update_address = coordinator.get_stats_publish_address()
        logger.info("Started DP Coordinator process (PID: %d)",
                    coordinator.proc.pid)

    if parallel_config.data_parallel_backend == "ray":
        logger.info("Starting ray-based data parallel backend")

        # Engines are managed by CoreEngineActorManager; this path binds no
        # local handshake socket and returns once everything has finished.
        engine_actor_manager = CoreEngineActorManager(
            vllm_config=vllm_config,
            addresses=addresses,
            executor_class=Executor.get_class(vllm_config),
            log_stats=not engine_args.disable_log_stats,
        )
        # Start API servers using the manager
        api_server_manager = APIServerProcessManager(
            target_server_fn=run_api_server_worker_proc,
            listen_address=listen_address,
            sock=sock,
            args=args,
            num_servers=num_api_servers,
            input_addresses=input_addresses,
            output_addresses=output_addresses,
            stats_update_address=stats_update_address)

        wait_for_completion_or_failure(api_server_manager=api_server_manager,
                                       engine_manager=engine_actor_manager,
                                       coordinator=coordinator)
        return

    handshake_address = get_engine_client_zmq_addr(
        local_only, host, parallel_config.data_parallel_rpc_port)

    with zmq_socket_ctx(handshake_address, zmq.ROUTER,
                        bind=True) as handshake_socket:

        # Start local engines (none if every engine runs on other nodes).
        if not local_engine_count:
            local_engine_manager = None
        else:
            local_engine_manager = CoreEngineProcManager(
                EngineCoreProc.run_engine_core,
                vllm_config=vllm_config,
                executor_class=Executor.get_class(vllm_config),
                log_stats=not engine_args.disable_log_stats,
                handshake_address=handshake_address,
                on_head_node=True,
                local_engine_count=local_engine_count,
                start_index=0,
                local_start_index=0)

        # Start API servers using the manager
        api_server_manager = APIServerProcessManager(
            target_server_fn=run_api_server_worker_proc,
            listen_address=listen_address,
            sock=sock,
            args=args,
            num_servers=num_api_servers,
            input_addresses=input_addresses,
            output_addresses=output_addresses,
            stats_update_address=stats_update_address)

        # Wait for engine handshakes to complete. Engines with index below
        # local_engine_count are the ones started on this node.
        core_engines = [
            CoreEngine(index=i, local=(i < local_engine_count))
            for i in range(dp_size)
        ]
        wait_for_engine_startup(
            handshake_socket,
            addresses,
            core_engines,
            parallel_config,
            vllm_config.cache_config,
            local_engine_manager,
            coordinator.proc if coordinator else None,
        )

        # Wait for API servers
        wait_for_completion_or_failure(api_server_manager=api_server_manager,
                                       engine_manager=local_engine_manager,
                                       coordinator=coordinator)
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328


def run_api_server_worker_proc(listen_address,
                               sock,
                               args,
                               client_config=None,
                               **uvicorn_kwargs) -> None:
    """Entrypoint for individual API server worker processes.

    Prefixes this process's stdout/stderr with its process name and PID,
    then runs ``run_server_worker`` to completion on a uvloop event loop.
    """
    from multiprocessing import current_process

    # Tag this worker's output so that lines from several API server
    # processes can be told apart.
    worker_name, worker_pid = current_process().name, os.getpid()
    for stream in (sys.stdout, sys.stderr):
        _add_prefix(stream, worker_name, worker_pid)

    server_coro = run_server_worker(listen_address, sock, args, client_config,
                                    **uvicorn_kwargs)
    uvloop.run(server_coro)