api_server.py 19.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import importlib
import inspect
6
import multiprocessing
7
import multiprocessing.forkserver as forkserver
8
import os
9
import signal
10
import socket
11
import tempfile
12
import warnings
13
from argparse import Namespace
14
from collections.abc import AsyncIterator
15
from contextlib import asynccontextmanager
16
from typing import Any
17

18
import uvloop
19
from fastapi import FastAPI, HTTPException
Zhuohan Li's avatar
Zhuohan Li committed
20
21
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
22
from starlette.datastructures import State
Zhuohan Li's avatar
Zhuohan Li committed
23

24
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
25
from vllm.engine.arg_utils import AsyncEngineArgs
26
from vllm.engine.protocol import EngineClient
27
from vllm.entrypoints.chat_utils import load_chat_template
28
from vllm.entrypoints.launcher import serve_http
29
from vllm.entrypoints.logger import RequestLogger
30
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
31
from vllm.entrypoints.openai.models.protocol import BaseModelPath
32
33
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.server_utils import (
34
35
    engine_error_handler,
    exception_handler,
36
37
38
39
40
    get_uvicorn_log_config,
    http_exception_handler,
    lifespan,
    log_response,
    validation_exception_handler,
41
)
42
from vllm.entrypoints.sagemaker.api_router import sagemaker_standards_bootstrap
43
44
45
46
from vllm.entrypoints.serve.elastic_ep.middleware import (
    ScalingMiddleware,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
47
48
49
from vllm.entrypoints.utils import (
    cli_env_setup,
    log_non_default_args,
50
    log_version_and_model,
51
    process_lora_modules,
52
)
53
from vllm.logger import init_logger
54
from vllm.reasoning import ReasoningParserManager
55
from vllm.tasks import POOLING_TASKS, SupportedTask
56
from vllm.tool_parsers import ToolParserManager
57
from vllm.tracing import instrument
yhu422's avatar
yhu422 committed
58
from vllm.usage.usage_lib import UsageContext
Cyrus Leung's avatar
Cyrus Leung committed
59
from vllm.utils.argparse_utils import FlexibleArgumentParser
60
from vllm.utils.network_utils import is_valid_ipv6_address
61
from vllm.utils.system_utils import decorate_logs, set_ulimit
62
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
63
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
64

65
# Bare annotation: declares the module-level name and its type without
# binding a value at import time.
prometheus_multiproc_dir: tempfile.TemporaryDirectory

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger("vllm.entrypoints.openai.api_server")

# Task set used by build_app() / init_app_state() when the caller does not
# pass ``supported_tasks`` explicitly (deprecated call style).
_FALLBACK_SUPPORTED_TASKS: tuple[SupportedTask, ...] = ("generate",)

72

73
@asynccontextmanager
async def build_async_engine_client(
    args: Namespace,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool | None = None,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """Build an engine client from parsed CLI args and yield it.

    Args:
        args: Parsed CLI namespace; converted with
            ``AsyncEngineArgs.from_cli_args``.
        usage_context: Forwarded to the engine-args based builder.
        disable_frontend_multiprocessing: When ``None``, the value is taken
            from ``args.disable_frontend_multiprocessing``.
        client_config: Optional client settings. Its ``client_count`` and
            ``client_index`` entries (defaults 1 and 0) are copied onto the
            engine args before the dict is forwarded unchanged.

    Yields:
        The engine client produced by
        ``build_async_engine_client_from_engine_args``.
    """
    if os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver":
        # The executor is expected to be mp.
        # Pre-import heavy modules in the forkserver process
        logger.debug("Setup forkserver with pre-imports")
        # set_start_method() raises RuntimeError once the start method has
        # been fixed. Skip it when forkserver is already selected so that
        # entering this context manager again in the same process works;
        # a conflicting earlier choice still raises as before.
        if multiprocessing.get_start_method(allow_none=True) != "forkserver":
            multiprocessing.set_start_method("forkserver")
        multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
        forkserver.ensure_running()
        logger.debug("Forkserver setup complete!")

    # Context manager to handle engine_client lifecycle
    # Ensures everything is shutdown and cleaned up on error/exit
    engine_args = AsyncEngineArgs.from_cli_args(args)
    if client_config:
        engine_args._api_process_count = client_config.get("client_count", 1)
        engine_args._api_process_rank = client_config.get("client_index", 0)

    if disable_frontend_multiprocessing is None:
        disable_frontend_multiprocessing = bool(args.disable_frontend_multiprocessing)

    async with build_async_engine_client_from_engine_args(
        engine_args,
        usage_context=usage_context,
        disable_frontend_multiprocessing=disable_frontend_multiprocessing,
        client_config=client_config,
    ) as engine:
        yield engine


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool = False,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """
    Create an ``AsyncLLM`` client from ``engine_args`` and yield it.

    ``client_config`` is copied before use; its ``client_count`` and
    ``client_index`` entries (defaults 1 and 0) are split off and the
    remaining entries are passed to ``AsyncLLM`` as ``client_addresses``.

    If a client was created, it is shut down when the context exits,
    whether normally or because of an exception.
    """

    # Build the engine configuration from the engine args.
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    if disable_frontend_multiprocessing:
        logger.warning("V1 is enabled, but got --disable-frontend-multiprocessing.")

    from vllm.v1.engine.async_llm import AsyncLLM

    # Work on a private copy so the caller's dict is never mutated.
    addresses = {} if not client_config else dict(client_config)
    n_clients = addresses.pop("client_count", 1)
    client_rank = addresses.pop("client_index", 0)

    # Stays None if construction fails, so ``finally`` knows whether there
    # is anything to shut down.
    engine: AsyncLLM | None = None
    try:
        engine = AsyncLLM.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            enable_log_requests=engine_args.enable_log_requests,
            aggregate_engine_logging=engine_args.aggregate_engine_logging,
            disable_log_stats=engine_args.disable_log_stats,
            client_addresses=addresses,
            client_count=n_clients,
            client_index=client_rank,
        )
        assert engine is not None

        # Don't keep the dummy data in memory
        await engine.reset_mm_cache()

        yield engine
    finally:
        if engine:
            engine.shutdown()
160
161


162
163
164
165
166
167
168
169
170
171
172
173
174
def build_app(
    args: Namespace, supported_tasks: tuple["SupportedTask", ...] | None = None
) -> FastAPI:
    """Create the FastAPI application and wire routers and middleware.

    Args:
        args: Parsed CLI namespace; also stored on ``app.state.args``.
        supported_tasks: Tasks the engine supports. Omitting it is
            deprecated and falls back to ``_FALLBACK_SUPPORTED_TASKS``.

    Returns:
        The app returned by ``sagemaker_standards_bootstrap``.

    Raises:
        ValueError: If an entry of ``args.middleware`` resolves to something
            that is neither a class nor a coroutine function.
    """
    if supported_tasks is None:
        warnings.warn(
            "The 'supported_tasks' parameter was not provided to "
            "build_app and will be required in a future version. "
            "Defaulting to ('generate',).",
            DeprecationWarning,
            stacklevel=2,
        )
        supported_tasks = _FALLBACK_SUPPORTED_TASKS

    # Docs endpoints are switched off by passing None for their URLs.
    fastapi_kwargs: dict[str, Any] = {"lifespan": lifespan}
    if args.disable_fastapi_docs:
        fastapi_kwargs.update(openapi_url=None, docs_url=None, redoc_url=None)
    elif args.enable_offline_docs:
        fastapi_kwargs.update(docs_url=None, redoc_url=None)
    app = FastAPI(**fastapi_kwargs)
    app.state.args = args

    # --- Routers that are always present -------------------------------
    from vllm.entrypoints.serve import register_vllm_serve_api_routers

    register_vllm_serve_api_routers(app)

    from vllm.entrypoints.openai.models.api_router import (
        attach_router as register_models_api_router,
    )

    register_models_api_router(app)

    from vllm.entrypoints.sagemaker.api_router import (
        attach_router as register_sagemaker_api_router,
    )

    register_sagemaker_api_router(app, supported_tasks)

    # --- Task-dependent routers ----------------------------------------
    if "generate" in supported_tasks or "render" in supported_tasks:
        from vllm.entrypoints.openai.generate.api_router import (
            register_generate_api_routers,
        )

        register_generate_api_routers(app)

        from vllm.entrypoints.serve.disagg.api_router import (
            attach_router as attach_disagg_router,
        )

        attach_disagg_router(app)

        from vllm.entrypoints.serve.rlhf.api_router import (
            attach_router as attach_rlhf_router,
        )

        attach_rlhf_router(app)

        from vllm.entrypoints.serve.elastic_ep.api_router import (
            attach_router as elastic_ep_attach_router,
        )

        elastic_ep_attach_router(app)

    if "transcription" in supported_tasks:
        from vllm.entrypoints.openai.speech_to_text.api_router import (
            attach_router as register_speech_to_text_api_router,
        )

        register_speech_to_text_api_router(app)

    if "realtime" in supported_tasks:
        from vllm.entrypoints.openai.realtime.api_router import (
            attach_router as register_realtime_api_router,
        )

        register_realtime_api_router(app)

    if any(task in POOLING_TASKS for task in supported_tasks):
        from vllm.entrypoints.pooling import register_pooling_api_routers

        register_pooling_api_routers(app, supported_tasks)

    app.root_path = args.root_path

    # --- Middleware and exception handlers ------------------------------
    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    for exc_type, handler in (
        (HTTPException, http_exception_handler),
        (RequestValidationError, validation_exception_handler),
        (EngineGenerateError, engine_error_handler),
        (EngineDeadError, engine_error_handler),
        (Exception, exception_handler),
    ):
        app.exception_handler(exc_type)(handler)

    # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
    candidate_keys = args.api_key or [envs.VLLM_API_KEY]
    tokens = [key for key in candidate_keys if key]
    if tokens:
        from vllm.entrypoints.openai.server_utils import AuthenticationMiddleware

        app.add_middleware(AuthenticationMiddleware, tokens=tokens)

    if args.enable_request_id_headers:
        from vllm.entrypoints.openai.server_utils import XRequestIdMiddleware

        app.add_middleware(XRequestIdMiddleware)

    # Add scaling middleware to check for scaling state
    app.add_middleware(ScalingMiddleware)

    if "realtime" in supported_tasks:
        # Add WebSocket metrics middleware
        from vllm.entrypoints.openai.realtime.metrics import (
            WebSocketMetricsMiddleware,
        )

        app.add_middleware(WebSocketMetricsMiddleware)

    if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
        logger.warning(
            "CAUTION: Enabling log response in the API Server. "
            "This can include sensitive information and should be "
            "avoided in production."
        )
        app.middleware("http")(log_response)

    # User-supplied middleware, given as dotted "package.module.object" paths.
    for dotted_path in args.middleware:
        module_path, object_name = dotted_path.rsplit(".", 1)
        target = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(target):
            app.add_middleware(target)  # type: ignore[arg-type]
            continue
        if inspect.iscoroutinefunction(target):
            app.middleware("http")(target)
            continue
        raise ValueError(
            f"Invalid middleware {dotted_path}. Must be a function or a class."
        )

    return sagemaker_standards_bootstrap(app)


306
async def init_app_state(
    engine_client: EngineClient,
    state: State,
    args: Namespace,
    supported_tasks: tuple["SupportedTask", ...] | None = None,
) -> None:
    """Populate ``state`` with the serving objects the API routes use.

    Args:
        engine_client: Engine client; its ``vllm_config`` is stored on
            ``state`` and it is handed to every serving object.
        state: Application state object that is filled in place.
        args: Parsed CLI namespace.
        supported_tasks: Tasks the engine supports. Omitting it is
            deprecated and falls back to ``_FALLBACK_SUPPORTED_TASKS``.
    """
    vllm_config = engine_client.vllm_config

    if supported_tasks is None:
        warnings.warn(
            "The 'supported_tasks' parameter was not provided to "
            "init_app_state and will be required in a future version. "
            "Please pass 'supported_tasks' explicitly.",
            DeprecationWarning,
            stacklevel=2,
        )
        supported_tasks = _FALLBACK_SUPPORTED_TASKS

    # Fall back to the model path itself when no served name was given.
    served_model_names = (
        [args.model] if args.served_model_name is None else args.served_model_name
    )

    request_logger = (
        RequestLogger(max_log_len=args.max_log_len)
        if args.enable_log_requests
        else None
    )

    base_model_paths = [
        BaseModelPath(name=served_name, model_path=args.model)
        for served_name in served_model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats
    state.vllm_config = vllm_config
    state.args = args
    resolved_chat_template = load_chat_template(args.chat_template)

    # Merge default_mm_loras into the static lora_modules
    lora_config = vllm_config.lora_config
    default_mm_loras = {} if lora_config is None else lora_config.default_mm_loras
    lora_modules = process_lora_modules(args.lora_modules, default_mm_loras)

    serving_models = OpenAIServingModels(
        engine_client=engine_client,
        base_model_paths=base_model_paths,
        lora_modules=lora_modules,
    )
    state.openai_serving_models = serving_models
    await serving_models.init_static_loras()
    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        serving_models,
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
        trust_request_chat_template=args.trust_request_chat_template,
    )

    # Every task-specific initializer below takes the same positional args.
    task_init_args = (engine_client, state, args, request_logger, supported_tasks)

    if "generate" in supported_tasks or "render" in supported_tasks:
        from vllm.entrypoints.openai.generate.api_router import init_generate_state

        await init_generate_state(*task_init_args)

    if "transcription" in supported_tasks:
        from vllm.entrypoints.openai.speech_to_text.api_router import (
            init_transcription_state,
        )

        init_transcription_state(*task_init_args)

    if "realtime" in supported_tasks:
        from vllm.entrypoints.openai.realtime.api_router import init_realtime_state

        init_realtime_state(*task_init_args)

    if any(task in POOLING_TASKS for task in supported_tasks):
        from vllm.entrypoints.pooling import init_pooling_state

        init_pooling_state(*task_init_args)

    state.enable_server_load_tracking = args.enable_server_load_tracking
    state.server_load_metrics = 0

395

396
def create_server_socket(addr: tuple[str, int]) -> socket.socket:
    """Create a TCP socket bound to ``addr`` (``listen`` is not called here).

    Args:
        addr: ``(host, port)`` pair. An IPv6 socket is used when the host is
            a valid IPv6 address, otherwise IPv4.

    Returns:
        The bound socket, with ``SO_REUSEADDR`` and ``SO_REUSEPORT`` set.

    Raises:
        OSError: If setting the options or binding fails. The socket is
            closed before the error propagates.
    """
    family = socket.AF_INET
    if is_valid_ipv6_address(addr[0]):
        family = socket.AF_INET6

    sock = socket.socket(family=family, type=socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(addr)
    except BaseException:
        # Don't leak the file descriptor when setup fails (e.g. the port is
        # already in use); the original error is re-raised unchanged.
        sock.close()
        raise

    return sock


409
410
411
412
413
414
def create_server_unix_socket(path: str) -> socket.socket:
    """Create a Unix-domain stream socket bound to ``path``.

    Args:
        path: Filesystem path to bind the socket to.

    Returns:
        The bound socket (``listen`` is not called here).

    Raises:
        OSError: If binding fails. The socket is closed before the error
            propagates.
    """
    sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
    try:
        sock.bind(path)
    except BaseException:
        # Don't leak the file descriptor when the bind fails; the original
        # error is re-raised unchanged.
        sock.close()
        raise
    return sock


415
def validate_api_server_args(args):
    """Check that the requested tool and reasoning parsers are registered.

    Raises:
        KeyError: If auto tool choice is enabled with an unregistered
            ``args.tool_call_parser``, or if a reasoning parser is set in
            ``args.structured_outputs_config`` but is not registered.
    """
    known_tool_parsers = ToolParserManager.list_registered()
    if args.enable_auto_tool_choice:
        if args.tool_call_parser not in known_tool_parsers:
            raise KeyError(
                f"invalid tool call parser: {args.tool_call_parser} "
                f"(chose from {{ {','.join(known_tool_parsers)} }})"
            )

    known_reasoning_parsers = ReasoningParserManager.list_registered()
    reasoning_parser = args.structured_outputs_config.reasoning_parser
    # An unset (falsy) reasoning parser is always acceptable.
    if not reasoning_parser:
        return
    if reasoning_parser not in known_reasoning_parsers:
        raise KeyError(
            f"invalid reasoning parser: {reasoning_parser} "
            f"(chose from {{ {','.join(known_reasoning_parsers)} }})"
        )
431

432

433
@instrument(span_name="API server setup")
def setup_server(args):
    """Validate API server args, install a SIGTERM handler and bind the
    server socket.

    Returns:
        A ``(listen_address, sock)`` tuple: a printable address
        (``unix:<path>`` or ``http[s]://host:port``) and the bound socket.
    """

    log_version_and_model(logger, VLLM_VERSION, args.model)
    log_non_default_args(args)

    tool_plugin = args.tool_parser_plugin
    if tool_plugin and len(tool_plugin) > 3:
        ToolParserManager.import_tool_parser(tool_plugin)

    reasoning_plugin = args.reasoning_parser_plugin
    if reasoning_plugin and len(reasoning_plugin) > 3:
        ReasoningParserManager.import_reasoning_parser(reasoning_plugin)

    validate_api_server_args(args)

    # workaround to make sure that we bind the port before the engine is set up.
    # This avoids race conditions with ray.
    # see https://github.com/vllm-project/vllm/issues/8204
    if args.uds:
        sock = create_server_unix_socket(args.uds)
    else:
        sock_addr = (args.host or "", args.port)
        sock = create_server_socket(sock_addr)

    # workaround to avoid footguns where uvicorn drops requests with too
    # many concurrent requests active
    set_ulimit()

    def _interrupt_on_sigterm(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, _interrupt_on_sigterm)

    if args.uds:
        listen_address = f"unix:{args.uds}"
    else:
        host, port = sock_addr
        scheme = "https" if (args.ssl_keyfile and args.ssl_certfile) else "http"
        if is_valid_ipv6_address(host):
            # IPv6 literals are bracketed inside URLs.
            host_part = f"[{host}]"
        else:
            # An empty host is displayed as 0.0.0.0.
            host_part = host or "0.0.0.0"
        listen_address = f"{scheme}://{host_part}:{port}"

    return listen_address, sock


478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
async def build_and_serve(
    engine_client: EngineClient,
    listen_address: str,
    sock: socket.socket,
    args: Namespace,
    **uvicorn_kwargs,
) -> asyncio.Task:
    """Build the FastAPI app, initialize its state, and start serving.

    Args:
        engine_client: Engine client used to query supported tasks and to
            initialize the app state.
        listen_address: Address string, used only for the startup log line.
        sock: Socket handed to ``serve_http``.
        args: Parsed CLI namespace.
        **uvicorn_kwargs: Extra keyword arguments forwarded to
            ``serve_http``; ``log_config`` is overridden when a uvicorn log
            config is available.

    Returns:
        The value returned by ``serve_http``: the shutdown task for the
        caller to await.
    """

    # Get uvicorn log config (from file or with endpoint filter)
    if (uvicorn_log_config := get_uvicorn_log_config(args)) is not None:
        uvicorn_kwargs["log_config"] = uvicorn_log_config

    supported_tasks = await engine_client.get_supported_tasks()
    logger.info("Supported tasks: %s", supported_tasks)

    app = build_app(args, supported_tasks)
    await init_app_state(engine_client, app.state, args, supported_tasks)

    logger.info("Starting vLLM server on %s", listen_address)

    shutdown_task = await serve_http(
        app,
        sock=sock,
        enable_ssl_refresh=args.enable_ssl_refresh,
        host=args.host,
        port=args.port,
        log_level=args.uvicorn_log_level,
        # NOTE: When the 'disable_uvicorn_access_log' value is True,
        # no access log will be output.
        access_log=not args.disable_uvicorn_access_log,
        timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
        ssl_ca_certs=args.ssl_ca_certs,
        ssl_cert_reqs=args.ssl_cert_reqs,
        ssl_ciphers=args.ssl_ciphers,
        h11_max_incomplete_event_size=args.h11_max_incomplete_event_size,
        h11_max_header_count=args.h11_max_header_count,
        **uvicorn_kwargs,
    )
    return shutdown_task


525
526
async def run_server(args, **uvicorn_kwargs) -> None:
    """Run a single-worker API server.

    Sets up the server socket via ``setup_server`` and then hands over to
    ``run_server_worker``; ``uvicorn_kwargs`` are forwarded unchanged.
    """

    # Add process-specific prefix to stdout and stderr.
    decorate_logs("APIServer")

    # setup_server returns (listen_address, sock), which are the first two
    # positional parameters of run_server_worker.
    address_and_sock = setup_server(args)
    await run_server_worker(*address_and_sock, args, **uvicorn_kwargs)


535
536
537
async def run_server_worker(
    listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
    """Run a single API server worker.

    Args:
        listen_address: Address string passed through to ``build_and_serve``.
        sock: Already-bound server socket. It is closed when this coroutine
            finishes, including when startup fails.
        args: Parsed CLI namespace.
        client_config: Optional client settings forwarded to
            ``build_async_engine_client``.
        **uvicorn_kwargs: Extra keyword arguments forwarded to
            ``build_and_serve``.
    """
    # The try block covers startup too, so the socket is released even when
    # plugin import, engine creation or app construction raises.
    try:
        if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
            ToolParserManager.import_tool_parser(args.tool_parser_plugin)

        if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3:
            ReasoningParserManager.import_reasoning_parser(
                args.reasoning_parser_plugin
            )

        async with build_async_engine_client(
            args,
            client_config=client_config,
        ) as engine_client:
            shutdown_task = await build_and_serve(
                engine_client, listen_address, sock, args, **uvicorn_kwargs
            )

        # NB: Await server shutdown only after the backend context is exited
        await shutdown_task
    finally:
        sock.close()
558

Ethan Xu's avatar
Ethan Xu committed
559
560
561

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
    # entrypoints.
    cli_env_setup()

    # Build the parser, then parse and validate the command line.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."
        )
    )
    args = parser.parse_args()
    validate_parsed_serve_args(args)

    # Run the server coroutine to completion on a uvloop event loop.
    uvloop.run(run_server(args))