api_server.py 19.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import importlib
import inspect
6
import multiprocessing
7
import multiprocessing.forkserver as forkserver
8
import os
9
import signal
10
import socket
11
import tempfile
12
import warnings
13
from argparse import Namespace
14
from collections.abc import AsyncIterator
15
from contextlib import asynccontextmanager
16
from typing import Any
17

18
import uvloop
19
from fastapi import FastAPI, HTTPException
Zhuohan Li's avatar
Zhuohan Li committed
20
21
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
22
from starlette.datastructures import State
Zhuohan Li's avatar
Zhuohan Li committed
23

24
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
25
from vllm.engine.arg_utils import AsyncEngineArgs
26
from vllm.engine.protocol import EngineClient
27
from vllm.entrypoints.chat_utils import load_chat_template
28
from vllm.entrypoints.launcher import serve_http
29
from vllm.entrypoints.logger import RequestLogger
30
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
31
from vllm.entrypoints.openai.models.protocol import BaseModelPath
32
33
34
35
36
37
38
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.server_utils import (
    get_uvicorn_log_config,
    http_exception_handler,
    lifespan,
    log_response,
    validation_exception_handler,
39
)
40
from vllm.entrypoints.sagemaker.api_router import sagemaker_standards_bootstrap
41
42
43
44
from vllm.entrypoints.serve.elastic_ep.middleware import (
    ScalingMiddleware,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
45
46
47
from vllm.entrypoints.utils import (
    cli_env_setup,
    log_non_default_args,
48
    log_version_and_model,
49
    process_lora_modules,
50
)
51
from vllm.logger import init_logger
52
from vllm.reasoning import ReasoningParserManager
53
from vllm.tasks import POOLING_TASKS, SupportedTask
54
from vllm.tool_parsers import ToolParserManager
55
from vllm.tracing import instrument
yhu422's avatar
yhu422 committed
56
from vllm.usage.usage_lib import UsageContext
Cyrus Leung's avatar
Cyrus Leung committed
57
from vllm.utils.argparse_utils import FlexibleArgumentParser
58
from vllm.utils.network_utils import is_valid_ipv6_address
59
from vllm.utils.system_utils import decorate_logs, set_ulimit
60
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
61

62
# Annotation-only module-level declaration: no value is bound here, and no
# assignment to this name appears in the code visible in this module.
prometheus_multiproc_dir: tempfile.TemporaryDirectory
63

64
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
65
logger = init_logger("vllm.entrypoints.openai.api_server")
66

67
68
# Task tuple substituted by build_app() and init_app_state() when the caller
# does not pass `supported_tasks` (both emit a DeprecationWarning first).
_FALLBACK_SUPPORTED_TASKS: tuple[SupportedTask, ...] = ("generate",)

69

70
@asynccontextmanager
async def build_async_engine_client(
    args: Namespace,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool | None = None,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """Yield an engine client built from parsed CLI arguments.

    This is an async context manager. The client is owned by the nested
    ``build_async_engine_client_from_engine_args`` context, so that
    context's cleanup runs when this one exits, on error as well as on
    normal exit.

    Args:
        args: Parsed CLI namespace, converted with
            ``AsyncEngineArgs.from_cli_args``.
        usage_context: Forwarded unchanged to the nested builder.
        disable_frontend_multiprocessing: When ``None``, the value is taken
            from ``bool(args.disable_frontend_multiprocessing)``.
        client_config: Optional mapping. When non-empty, its
            ``client_count`` / ``client_index`` entries (defaults 1 / 0)
            are copied onto the engine args; the mapping itself is
            forwarded to the nested builder.
    """
    wants_forkserver = os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver"
    if wants_forkserver:
        # The executor is expected to be mp.
        # Pre-import heavy modules in the forkserver process
        logger.debug("Setup forkserver with pre-imports")
        multiprocessing.set_start_method("forkserver")
        multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
        forkserver.ensure_running()
        logger.debug("Forkserver setup complete!")

    engine_args = AsyncEngineArgs.from_cli_args(args)
    if client_config:
        engine_args._api_process_count = client_config.get("client_count", 1)
        engine_args._api_process_rank = client_config.get("client_index", 0)

    # An explicit argument wins; otherwise fall back to the CLI flag.
    if disable_frontend_multiprocessing is None:
        disable_frontend_multiprocessing = bool(args.disable_frontend_multiprocessing)

    # The nested context manager handles the engine client's lifecycle and
    # makes sure everything is shut down and cleaned up on error/exit.
    engine_context = build_async_engine_client_from_engine_args(
        engine_args,
        usage_context=usage_context,
        disable_frontend_multiprocessing=disable_frontend_multiprocessing,
        client_config=client_config,
    )
    async with engine_context as engine_client:
        yield engine_client


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool = False,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """
    Yield an ``AsyncLLM`` engine client built from ``engine_args``.

    The engine config is created first, then the client is constructed with
    ``AsyncLLM.from_vllm_config``. On exit from the context (normal or via
    an exception) the client is shut down if it was created.

    ``disable_frontend_multiprocessing`` only triggers a warning here.
    ``client_config`` is copied before use; its ``client_count`` and
    ``client_index`` entries (defaults 1 and 0) are split off and the
    remaining entries are passed on as ``client_addresses``.
    """

    # Create the EngineConfig (determines if we can use V1).
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    if disable_frontend_multiprocessing:
        logger.warning("V1 is enabled, but got --disable-frontend-multiprocessing.")

    from vllm.v1.engine.async_llm import AsyncLLM

    # Work on a private copy so the caller's mapping is never mutated.
    addresses: dict[str, Any] = dict(client_config) if client_config else {}
    count = addresses.pop("client_count", 1)
    index = addresses.pop("client_index", 0)

    engine: AsyncLLM | None = None
    try:
        engine = AsyncLLM.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            enable_log_requests=engine_args.enable_log_requests,
            aggregate_engine_logging=engine_args.aggregate_engine_logging,
            disable_log_stats=engine_args.disable_log_stats,
            client_addresses=addresses,
            client_count=count,
            client_index=index,
        )
        assert engine is not None

        # Don't keep the dummy data in memory
        await engine.reset_mm_cache()

        yield engine
    finally:
        # `engine` is still None when construction itself failed.
        if engine:
            engine.shutdown()
157
158


159
160
161
162
163
164
165
166
167
168
169
170
171
def build_app(
    args: Namespace, supported_tasks: tuple["SupportedTask", ...] | None = None
) -> FastAPI:
    """Assemble the FastAPI application for the API server.

    Creates the app (with documentation endpoints switched off according to
    ``args``), registers the routers that match ``supported_tasks``, then
    installs CORS, exception handlers and the optional middlewares, and
    finally passes the app through ``sagemaker_standards_bootstrap``.

    Args:
        args: Parsed CLI namespace; also stored on ``app.state.args``.
        supported_tasks: Tasks that decide which routers/middlewares are
            attached. Omitting it emits a ``DeprecationWarning`` and falls
            back to ``_FALLBACK_SUPPORTED_TASKS``.

    Returns:
        The app returned by ``sagemaker_standards_bootstrap``.

    Raises:
        ValueError: If an entry of ``args.middleware`` resolves to something
            that is neither a class nor a coroutine function.
    """
    if supported_tasks is None:
        warnings.warn(
            "The 'supported_tasks' parameter was not provided to "
            "build_app and will be required in a future version. "
            "Defaulting to ('generate',).",
            DeprecationWarning,
            stacklevel=2,
        )
        supported_tasks = _FALLBACK_SUPPORTED_TASKS

    # Every variant shares the lifespan hook; the flags only switch off
    # documentation-related endpoints.
    fastapi_options: dict[str, Any] = {"lifespan": lifespan}
    if args.disable_fastapi_docs:
        fastapi_options.update(openapi_url=None, docs_url=None, redoc_url=None)
    elif args.enable_offline_docs:
        fastapi_options.update(docs_url=None, redoc_url=None)
    app = FastAPI(**fastapi_options)
    app.state.args = args

    # --- Routers -----------------------------------------------------------
    from vllm.entrypoints.serve import register_vllm_serve_api_routers

    register_vllm_serve_api_routers(app)

    from vllm.entrypoints.openai.models.api_router import (
        attach_router as attach_models_router,
    )

    attach_models_router(app)

    from vllm.entrypoints.sagemaker.api_router import (
        attach_router as attach_sagemaker_router,
    )

    attach_sagemaker_router(app, supported_tasks)

    if "generate" in supported_tasks or "render" in supported_tasks:
        from vllm.entrypoints.openai.generate.api_router import (
            register_generate_api_routers,
        )

        register_generate_api_routers(app)

        from vllm.entrypoints.serve.disagg.api_router import (
            attach_router as attach_disagg,
        )

        attach_disagg(app)

        from vllm.entrypoints.serve.rlhf.api_router import (
            attach_router as attach_rlhf,
        )

        attach_rlhf(app)

        from vllm.entrypoints.serve.elastic_ep.api_router import (
            attach_router as attach_elastic_ep,
        )

        attach_elastic_ep(app)

    if "transcription" in supported_tasks:
        from vllm.entrypoints.openai.speech_to_text.api_router import (
            attach_router as attach_speech_to_text_router,
        )

        attach_speech_to_text_router(app)

    realtime_enabled = "realtime" in supported_tasks
    if realtime_enabled:
        from vllm.entrypoints.openai.realtime.api_router import (
            attach_router as attach_realtime_router,
        )

        attach_realtime_router(app)

    if any(supported in POOLING_TASKS for supported in supported_tasks):
        from vllm.entrypoints.pooling import register_pooling_api_routers

        register_pooling_api_routers(app, supported_tasks)

    app.root_path = args.root_path

    # --- Middlewares and exception handlers --------------------------------
    # The registration order below is kept exactly as-is.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    app.exception_handler(HTTPException)(http_exception_handler)
    app.exception_handler(RequestValidationError)(validation_exception_handler)

    # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
    candidate_keys = args.api_key or [envs.VLLM_API_KEY]
    tokens = [key for key in candidate_keys if key]
    if tokens:
        from vllm.entrypoints.openai.server_utils import AuthenticationMiddleware

        app.add_middleware(AuthenticationMiddleware, tokens=tokens)

    if args.enable_request_id_headers:
        from vllm.entrypoints.openai.server_utils import XRequestIdMiddleware

        app.add_middleware(XRequestIdMiddleware)

    # Add scaling middleware to check for scaling state
    app.add_middleware(ScalingMiddleware)

    if realtime_enabled:
        # Add WebSocket metrics middleware
        from vllm.entrypoints.openai.realtime.metrics import (
            WebSocketMetricsMiddleware,
        )

        app.add_middleware(WebSocketMetricsMiddleware)

    if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
        logger.warning(
            "CAUTION: Enabling log response in the API Server. "
            "This can include sensitive information and should be "
            "avoided in production."
        )
        app.middleware("http")(log_response)

    # User-supplied middlewares, given as dotted "module.object" paths.
    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        target = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(target):
            app.add_middleware(target)  # type: ignore[arg-type]
            continue
        if inspect.iscoroutinefunction(target):
            app.middleware("http")(target)
            continue
        raise ValueError(
            f"Invalid middleware {middleware}. Must be a function or a class."
        )

    return sagemaker_standards_bootstrap(app)


300
async def init_app_state(
    engine_client: EngineClient,
    state: State,
    args: Namespace,
    supported_tasks: tuple["SupportedTask", ...] | None = None,
) -> None:
    """Populate ``state`` with the serving objects used by the routers.

    Stores the engine client, config and args on ``state``, builds the
    models and tokenization serving objects, and then runs the per-task
    state initializers that match ``supported_tasks``.

    Args:
        engine_client: Engine client; its ``vllm_config`` is read here.
        state: Application state object that receives the attributes.
        args: Parsed CLI namespace.
        supported_tasks: Tasks that decide which initializers run. Omitting
            it emits a ``DeprecationWarning`` and falls back to
            ``_FALLBACK_SUPPORTED_TASKS``.
    """
    vllm_config = engine_client.vllm_config

    if supported_tasks is None:
        warnings.warn(
            "The 'supported_tasks' parameter was not provided to "
            "init_app_state and will be required in a future version. "
            "Please pass 'supported_tasks' explicitly.",
            DeprecationWarning,
            stacklevel=2,
        )
        supported_tasks = _FALLBACK_SUPPORTED_TASKS

    # With no explicit served names, the model itself is the only name.
    served_model_names = (
        args.served_model_name
        if args.served_model_name is not None
        else [args.model]
    )

    request_logger = (
        RequestLogger(max_log_len=args.max_log_len)
        if args.enable_log_requests
        else None
    )

    base_model_paths = [
        BaseModelPath(name=served_name, model_path=args.model)
        for served_name in served_model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats
    state.vllm_config = vllm_config
    state.args = args
    resolved_chat_template = load_chat_template(args.chat_template)

    # Merge default_mm_loras into the static lora_modules
    lora_config = vllm_config.lora_config
    default_mm_loras = lora_config.default_mm_loras if lora_config is not None else {}
    lora_modules = process_lora_modules(args.lora_modules, default_mm_loras)

    serving_models = OpenAIServingModels(
        engine_client=engine_client,
        base_model_paths=base_model_paths,
        lora_modules=lora_modules,
    )
    state.openai_serving_models = serving_models
    await state.openai_serving_models.init_static_loras()
    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        state.openai_serving_models,
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
        trust_request_chat_template=args.trust_request_chat_template,
        log_error_stack=args.log_error_stack,
    )

    # Per-task initializers; every one receives the same argument tuple.
    initializer_args = (engine_client, state, args, request_logger, supported_tasks)

    if "generate" in supported_tasks or "render" in supported_tasks:
        from vllm.entrypoints.openai.generate.api_router import init_generate_state

        await init_generate_state(*initializer_args)

    if "transcription" in supported_tasks:
        from vllm.entrypoints.openai.speech_to_text.api_router import (
            init_transcription_state,
        )

        init_transcription_state(*initializer_args)

    if "realtime" in supported_tasks:
        from vllm.entrypoints.openai.realtime.api_router import init_realtime_state

        init_realtime_state(*initializer_args)

    if any(supported in POOLING_TASKS for supported in supported_tasks):
        from vllm.entrypoints.pooling import init_pooling_state

        init_pooling_state(*initializer_args)

    state.enable_server_load_tracking = args.enable_server_load_tracking
    state.server_load_metrics = 0

390

391
def create_server_socket(addr: tuple[str, int]) -> socket.socket:
    """Create a TCP socket bound to ``addr``.

    The address family is ``AF_INET6`` when ``is_valid_ipv6_address``
    accepts the host part and ``AF_INET`` otherwise. ``SO_REUSEADDR`` and
    ``SO_REUSEPORT`` are both enabled before binding.

    Args:
        addr: ``(host, port)`` pair to bind to.

    Returns:
        The bound socket; the caller is responsible for closing it.

    Raises:
        OSError: If setting a socket option or binding fails. The socket is
            closed before the error propagates.
    """
    if is_valid_ipv6_address(addr[0]):
        family = socket.AF_INET6
    else:
        family = socket.AF_INET

    sock = socket.socket(family=family, type=socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(addr)
    except BaseException:
        # Don't leak the file descriptor when configuration or bind fails.
        sock.close()
        raise

    return sock


404
405
406
407
408
409
def create_server_unix_socket(path: str) -> socket.socket:
    """Create a Unix domain stream socket bound to ``path``.

    Args:
        path: Filesystem path to bind the socket to.

    Returns:
        The bound socket; the caller is responsible for closing it.

    Raises:
        OSError: If binding fails. The socket is closed before the error
            propagates.
    """
    sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
    try:
        sock.bind(path)
    except BaseException:
        # Don't leak the file descriptor when bind fails.
        sock.close()
        raise
    return sock


410
def validate_api_server_args(args):
    """Check that the requested tool-call and reasoning parsers are registered.

    The tool-call parser is only checked when ``args.enable_auto_tool_choice``
    is truthy; the reasoning parser is only checked when
    ``args.structured_outputs_config.reasoning_parser`` is truthy.

    Raises:
        KeyError: If a checked parser name is not in the corresponding
            manager's ``list_registered()`` result.
    """
    registered_tool_parsers = ToolParserManager.list_registered()
    if args.enable_auto_tool_choice:
        if args.tool_call_parser not in registered_tool_parsers:
            tool_choices = ",".join(registered_tool_parsers)
            raise KeyError(
                f"invalid tool call parser: {args.tool_call_parser} "
                f"(chose from {{ {tool_choices} }})"
            )

    registered_reasoning_parsers = ReasoningParserManager.list_registered()
    reasoning_parser = args.structured_outputs_config.reasoning_parser
    if reasoning_parser and reasoning_parser not in registered_reasoning_parsers:
        reasoning_choices = ",".join(registered_reasoning_parsers)
        raise KeyError(
            f"invalid reasoning parser: {reasoning_parser} "
            f"(chose from {{ {reasoning_choices} }})"
        )
426

427

428
@instrument(span_name="API server setup")
def setup_server(args):
    """Prepare the API server before the engine is created.

    Logs version and non-default arguments, imports any parser plugins,
    validates the arguments, binds the listening socket, raises the ulimit,
    and installs a SIGTERM handler that raises ``KeyboardInterrupt``.

    Returns:
        A ``(listen_address, sock)`` tuple: a printable address string
        (``unix:<path>`` or ``http[s]://host:port``) and the bound socket.
    """
    log_version_and_model(logger, VLLM_VERSION, args.model)
    log_non_default_args(args)

    tool_plugin = args.tool_parser_plugin
    if tool_plugin and len(tool_plugin) > 3:
        ToolParserManager.import_tool_parser(tool_plugin)

    reasoning_plugin = args.reasoning_parser_plugin
    if reasoning_plugin and len(reasoning_plugin) > 3:
        ReasoningParserManager.import_reasoning_parser(reasoning_plugin)

    validate_api_server_args(args)

    # workaround to make sure that we bind the port before the engine is set up.
    # This avoids race conditions with ray.
    # see https://github.com/vllm-project/vllm/issues/8204
    if args.uds:
        sock = create_server_unix_socket(args.uds)
    else:
        sock_addr = (args.host or "", args.port)
        sock = create_server_socket(sock_addr)

    # workaround to avoid footguns where uvicorn drops requests with too
    # many concurrent requests active
    set_ulimit()

    def _interrupt_on_sigterm(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, _interrupt_on_sigterm)

    if args.uds:
        return f"unix:{args.uds}", sock

    addr, port = sock_addr
    is_ssl = args.ssl_keyfile and args.ssl_certfile
    scheme = "https" if is_ssl else "http"
    if is_valid_ipv6_address(addr):
        # IPv6 literals are bracketed in the printed address.
        host_part = f"[{addr}]"
    else:
        host_part = addr or "0.0.0.0"
    return f"{scheme}://{host_part}:{port}", sock


473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
async def build_and_serve(
    engine_client: EngineClient,
    listen_address: str,
    sock: socket.socket,
    args: Namespace,
    **uvicorn_kwargs,
) -> asyncio.Task:
    """Build the FastAPI app, initialize its state and start serving.

    Args:
        engine_client: Engine client, queried for its supported tasks and
            handed to the app state.
        listen_address: Address string, used only for the startup log line.
        sock: Already-bound socket handed to ``serve_http``.
        args: Parsed CLI namespace.
        **uvicorn_kwargs: Extra keyword arguments forwarded to
            ``serve_http``; ``log_config`` is set here when
            ``get_uvicorn_log_config`` returns a value.

    Returns:
        The result of awaiting ``serve_http``: the shutdown task for the
        caller to await.
    """
    # Get uvicorn log config (from file or with endpoint filter)
    if (uvicorn_log_config := get_uvicorn_log_config(args)) is not None:
        uvicorn_kwargs["log_config"] = uvicorn_log_config

    tasks = await engine_client.get_supported_tasks()
    logger.info("Supported tasks: %s", tasks)

    app = build_app(args, tasks)
    await init_app_state(engine_client, app.state, args, tasks)

    logger.info("Starting vLLM server on %s", listen_address)

    shutdown_task = await serve_http(
        app,
        sock=sock,
        enable_ssl_refresh=args.enable_ssl_refresh,
        host=args.host,
        port=args.port,
        log_level=args.uvicorn_log_level,
        # NOTE: When the 'disable_uvicorn_access_log' value is True,
        # no access log will be output.
        access_log=not args.disable_uvicorn_access_log,
        timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
        ssl_ca_certs=args.ssl_ca_certs,
        ssl_cert_reqs=args.ssl_cert_reqs,
        ssl_ciphers=args.ssl_ciphers,
        h11_max_incomplete_event_size=args.h11_max_incomplete_event_size,
        h11_max_header_count=args.h11_max_header_count,
        **uvicorn_kwargs,
    )
    return shutdown_task


520
521
async def run_server(args, **uvicorn_kwargs) -> None:
    """Run a single-worker API server.

    Prefixes process output with "APIServer", performs the server setup
    (which returns the listen address and the bound socket) and then runs
    one worker with them.
    """
    # Add process-specific prefix to stdout and stderr.
    decorate_logs("APIServer")

    bound_address, server_sock = setup_server(args)
    await run_server_worker(bound_address, server_sock, args, **uvicorn_kwargs)


530
531
532
async def run_server_worker(
    listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
    """Run a single API server worker.

    Imports any configured parser plugins, builds the engine client, starts
    serving, and waits for the server's shutdown task. The socket is closed
    once that task finishes or raises.
    """
    tool_plugin = args.tool_parser_plugin
    if tool_plugin and len(tool_plugin) > 3:
        ToolParserManager.import_tool_parser(tool_plugin)

    reasoning_plugin = args.reasoning_parser_plugin
    if reasoning_plugin and len(reasoning_plugin) > 3:
        ReasoningParserManager.import_reasoning_parser(reasoning_plugin)

    engine_context = build_async_engine_client(args, client_config=client_config)
    async with engine_context as engine_client:
        shutdown_task = await build_and_serve(
            engine_client, listen_address, sock, args, **uvicorn_kwargs
        )

    # NB: Await server shutdown only after the backend context is exited
    try:
        await shutdown_task
    finally:
        sock.close()
553

Ethan Xu's avatar
Ethan Xu committed
554
555
556

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
    # entrypoints.
    cli_env_setup()

    # Build the parser in one step: the bare parser is created and handed
    # straight to make_arg_parser, whose return value is the parser used.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."
        )
    )
    args = parser.parse_args()
    validate_parsed_serve_args(args)

    uvloop.run(run_server(args))