api_server.py 17.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import importlib
import inspect
5
import multiprocessing
6
import multiprocessing.forkserver as forkserver
7
import os
8
import signal
9
import socket
10
import tempfile
11
from argparse import Namespace
12
from collections.abc import AsyncIterator
13
from contextlib import asynccontextmanager
14
from typing import Any
15

16
import uvloop
17
from fastapi import FastAPI, HTTPException
Zhuohan Li's avatar
Zhuohan Li committed
18
19
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
20
from starlette.datastructures import State
Zhuohan Li's avatar
Zhuohan Li committed
21

22
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm.engine.arg_utils import AsyncEngineArgs
24
from vllm.engine.protocol import EngineClient
25
from vllm.entrypoints.chat_utils import load_chat_template
26
from vllm.entrypoints.launcher import serve_http
27
from vllm.entrypoints.logger import RequestLogger
28
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
29
from vllm.entrypoints.openai.models.protocol import BaseModelPath
30
31
32
33
34
35
36
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.server_utils import (
    get_uvicorn_log_config,
    http_exception_handler,
    lifespan,
    log_response,
    validation_exception_handler,
37
)
38
from vllm.entrypoints.sagemaker.api_router import sagemaker_standards_bootstrap
39
40
41
42
from vllm.entrypoints.serve.elastic_ep.middleware import (
    ScalingMiddleware,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
43
44
45
from vllm.entrypoints.utils import (
    cli_env_setup,
    log_non_default_args,
46
    log_version_and_model,
47
    process_lora_modules,
48
)
49
from vllm.logger import init_logger
50
from vllm.reasoning import ReasoningParserManager
51
from vllm.tasks import POOLING_TASKS, SupportedTask
52
from vllm.tool_parsers import ToolParserManager
53
from vllm.tracing import instrument
yhu422's avatar
yhu422 committed
54
from vllm.usage.usage_lib import UsageContext
Cyrus Leung's avatar
Cyrus Leung committed
55
from vllm.utils.argparse_utils import FlexibleArgumentParser
56
from vllm.utils.network_utils import is_valid_ipv6_address
57
from vllm.utils.system_utils import decorate_logs, set_ulimit
58
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
59

60
# Module-level annotation only: this statement binds no value.
# NOTE(review): presumably assigned elsewhere when Prometheus multiprocess
# metrics are enabled — confirm.
prometheus_multiproc_dir: tempfile.TemporaryDirectory

# The logger name is pinned explicitly because __name__ cannot be used here
# (https://github.com/vllm-project/vllm/pull/4765).
_LOGGER_NAME = "vllm.entrypoints.openai.api_server"
logger = init_logger(_LOGGER_NAME)
64

65

66
@asynccontextmanager
async def build_async_engine_client(
    args: Namespace,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool | None = None,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """Yield an engine client built from parsed CLI arguments.

    Thin adapter over ``build_async_engine_client_from_engine_args``:
    converts ``args`` into ``AsyncEngineArgs``, applies the optional
    per-API-server settings from ``client_config``, resolves the
    frontend-multiprocessing flag and then delegates the engine lifecycle
    (creation, and cleanup on error/exit) to that context manager.

    Args:
        args: Parsed CLI namespace.
        usage_context: Usage context forwarded to the engine builder.
        disable_frontend_multiprocessing: Explicit override; when ``None``
            the value of ``args.disable_frontend_multiprocessing`` is used.
        client_config: Optional dict. Its ``"client_count"`` (default 1) and
            ``"client_index"`` (default 0) entries are copied onto the engine
            args; the dict itself is forwarded unchanged.
    """
    if os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver":
        # The executor is expected to be mp. Start the forkserver now and
        # have it pre-import the heavy engine module before any child is
        # forked from it.
        logger.debug("Setup forkserver with pre-imports")
        multiprocessing.set_start_method("forkserver")
        multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
        forkserver.ensure_running()
        logger.debug("Forkserver setup complete!")

    engine_args = AsyncEngineArgs.from_cli_args(args)
    if client_config:
        engine_args._api_process_count = client_config.get("client_count", 1)
        engine_args._api_process_rank = client_config.get("client_index", 0)

    if disable_frontend_multiprocessing is None:
        frontend_mp_disabled = bool(args.disable_frontend_multiprocessing)
    else:
        frontend_mp_disabled = disable_frontend_multiprocessing

    async with build_async_engine_client_from_engine_args(
        engine_args,
        usage_context=usage_context,
        disable_frontend_multiprocessing=frontend_mp_disabled,
        client_config=client_config,
    ) as client:
        yield client


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool = False,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """Create an in-process ``AsyncLLM`` engine client and yield it.

    The engine is always an ``AsyncLLM`` built from
    ``engine_args.create_engine_config(...)``;
    ``disable_frontend_multiprocessing`` only triggers a warning and has no
    other effect here.

    Args:
        engine_args: Engine arguments used to build the ``VllmConfig``.
        usage_context: Forwarded to config creation and to the engine.
        disable_frontend_multiprocessing: Only causes a warning to be logged.
        client_config: Optional mapping. ``"client_count"`` (default 1) and
            ``"client_index"`` (default 0) are extracted; all remaining
            entries are passed as ``client_addresses``. The caller's dict is
            never mutated.

    Yields:
        The constructed ``AsyncLLM``. It is shut down when the context exits,
        whether normally or through an exception.

    Raises:
        Whatever config creation or engine construction raises. No ``None``
        client is ever yielded.
    """
    # Create the EngineConfig (determines if we can use V1).
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    if disable_frontend_multiprocessing:
        # The flag is not used anywhere below; just tell the user.
        logger.warning("V1 is enabled, but got --disable-frontend-multiprocessing.")

    from vllm.v1.engine.async_llm import AsyncLLM

    async_llm: AsyncLLM | None = None

    # Work on a copy so the caller's client_config is never mutated.
    client_config = dict(client_config) if client_config else {}
    client_count = client_config.pop("client_count", 1)
    client_index = client_config.pop("client_index", 0)

    try:
        async_llm = AsyncLLM.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            enable_log_requests=engine_args.enable_log_requests,
            aggregate_engine_logging=engine_args.aggregate_engine_logging,
            disable_log_stats=engine_args.disable_log_stats,
            # Whatever remains after popping the count/index entries.
            client_addresses=client_config,
            client_count=client_count,
            client_index=client_index,
        )
        assert async_llm is not None

        # Don't keep the dummy data in memory: clear the multimodal cache
        # before serving. NOTE(review): presumably filled during startup
        # profiling — confirm.
        await async_llm.reset_mm_cache()

        yield async_llm
    finally:
        # Explicit None check: shutdown is skipped only when construction
        # failed, independent of how the engine object defines truthiness.
        if async_llm is not None:
            async_llm.shutdown()
153
154


155
def _add_custom_middlewares(app: FastAPI, middleware_specs) -> None:
    """Import user-supplied middlewares by dotted path and install them on ``app``.

    Each entry must look like ``"package.module.Name"``. Classes are added
    with ``app.add_middleware``; coroutine functions are registered as
    ``"http"`` middleware. Entries are processed in the given order.

    Raises:
        ValueError: If an entry is not a dotted import path, or the imported
            object is neither a class nor a coroutine function.
    """
    for spec in middleware_specs:
        # rpartition never fails, unlike unpacking rsplit(".", 1), so a
        # malformed spec gets a clear message instead of an unpacking error.
        module_path, dot, object_name = spec.rpartition(".")
        if not (dot and module_path and object_name):
            raise ValueError(
                f"Invalid middleware {spec}. Expected a dotted import path "
                "such as 'package.module.Name'."
            )
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)  # type: ignore[arg-type]
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(
                f"Invalid middleware {spec}. Must be a function or a class."
            )


def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) -> FastAPI:
    """Create and configure the FastAPI application for the API server.

    In order:
    - create the ``FastAPI`` instance, with built-in docs routes disabled
      when ``args.disable_fastapi_docs`` or ``args.enable_offline_docs`` is
      set;
    - attach the basic, vLLM-serve, models and SageMaker routers, plus the
      task-specific routers for ``supported_tasks``;
    - set ``root_path`` and install CORS, the exception handlers, and the
      optional and built-in middlewares;
    - wrap the app with ``sagemaker_standards_bootstrap``.

    Args:
        args: Parsed CLI namespace; also stored on ``app.state.args``.
        supported_tasks: Tasks reported by the engine; selects which
            task-specific routers are registered.

    Returns:
        The application returned by ``sagemaker_standards_bootstrap``.

    Raises:
        ValueError: If an ``args.middleware`` entry is malformed or resolves
            to something that is neither a class nor a coroutine function.
    """
    if args.disable_fastapi_docs:
        app = FastAPI(
            openapi_url=None, docs_url=None, redoc_url=None, lifespan=lifespan
        )
    elif args.enable_offline_docs:
        app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)
    else:
        app = FastAPI(lifespan=lifespan)
    app.state.args = args

    from vllm.entrypoints.openai.basic.api_router import register_basic_api_routers

    register_basic_api_routers(app)

    from vllm.entrypoints.serve import register_vllm_serve_api_routers

    register_vllm_serve_api_routers(app)

    from vllm.entrypoints.openai.models.api_router import (
        attach_router as register_models_api_router,
    )

    register_models_api_router(app)

    from vllm.entrypoints.sagemaker.api_router import (
        attach_router as register_sagemaker_api_router,
    )

    register_sagemaker_api_router(app, supported_tasks)

    if "generate" in supported_tasks:
        from vllm.entrypoints.openai.generate.api_router import (
            register_generate_api_routers,
        )

        register_generate_api_routers(app)

    if "transcription" in supported_tasks:
        from vllm.entrypoints.openai.speech_to_text.api_router import (
            attach_router as register_speech_to_text_api_router,
        )

        register_speech_to_text_api_router(app)

    if "realtime" in supported_tasks:
        from vllm.entrypoints.openai.realtime.api_router import (
            attach_router as register_realtime_api_router,
        )

        register_realtime_api_router(app)

    if any(task in POOLING_TASKS for task in supported_tasks):
        from vllm.entrypoints.pooling import register_pooling_api_routers

        register_pooling_api_routers(app, supported_tasks)

    app.root_path = args.root_path

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    app.exception_handler(HTTPException)(http_exception_handler)
    app.exception_handler(RequestValidationError)(validation_exception_handler)

    # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
    if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]:
        from vllm.entrypoints.openai.server_utils import AuthenticationMiddleware

        app.add_middleware(AuthenticationMiddleware, tokens=tokens)

    if args.enable_request_id_headers:
        from vllm.entrypoints.openai.server_utils import XRequestIdMiddleware

        app.add_middleware(XRequestIdMiddleware)

    # Add scaling middleware to check for scaling state
    app.add_middleware(ScalingMiddleware)

    if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
        logger.warning(
            "CAUTION: Enabling log response in the API Server. "
            "This can include sensitive information and should be "
            "avoided in production."
        )
        app.middleware("http")(log_response)

    # User-supplied middlewares go last, preserving their CLI order.
    _add_custom_middlewares(app, args.middleware)

    app = sagemaker_standards_bootstrap(app)

    return app


262
async def init_app_state(
    engine_client: EngineClient,
    state: State,
    args: Namespace,
    supported_tasks: tuple["SupportedTask", ...],
) -> None:
    """Populate ``state`` with the objects the API route handlers read.

    Stores the engine client, its config, the CLI args and the log-stats flag
    on ``state``. Builds the model registry (``openai_serving_models``,
    including static LoRA modules) and the tokenization handler. Then runs
    the task-specific initializer for each task group present in
    ``supported_tasks``, and finally initializes the server-load tracking
    fields.
    """
    vllm_config = engine_client.vllm_config

    model_names = (
        args.served_model_name
        if args.served_model_name is not None
        else [args.model]
    )
    request_logger = (
        RequestLogger(max_log_len=args.max_log_len)
        if args.enable_log_requests
        else None
    )
    # Every served name maps onto the same underlying model path.
    base_model_paths = [
        BaseModelPath(name=model_name, model_path=args.model)
        for model_name in model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats
    state.vllm_config = vllm_config
    state.args = args
    chat_template = load_chat_template(args.chat_template)

    # Static LoRA modules are the CLI-provided ones merged with the config's
    # default multimodal LoRAs (an empty mapping when LoRA is not configured).
    lora_config = vllm_config.lora_config
    mm_loras = {} if lora_config is None else lora_config.default_mm_loras
    lora_modules = process_lora_modules(args.lora_modules, mm_loras)

    serving_models = OpenAIServingModels(
        engine_client=engine_client,
        base_model_paths=base_model_paths,
        lora_modules=lora_modules,
    )
    state.openai_serving_models = serving_models
    await serving_models.init_static_loras()
    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        serving_models,
        request_logger=request_logger,
        chat_template=chat_template,
        chat_template_content_format=args.chat_template_content_format,
        trust_request_chat_template=args.trust_request_chat_template,
        log_error_stack=args.log_error_stack,
    )

    # Task-specific state; the import is deferred until the task is known to
    # be supported.
    if "generate" in supported_tasks:
        from vllm.entrypoints.openai.generate.api_router import init_generate_state

        await init_generate_state(
            engine_client, state, args, request_logger, supported_tasks
        )

    if "transcription" in supported_tasks:
        from vllm.entrypoints.openai.speech_to_text.api_router import (
            init_transcription_state,
        )

        init_transcription_state(
            engine_client, state, args, request_logger, supported_tasks
        )

    if "realtime" in supported_tasks:
        from vllm.entrypoints.openai.realtime.api_router import init_realtime_state

        init_realtime_state(engine_client, state, args, request_logger, supported_tasks)

    if any(task in POOLING_TASKS for task in supported_tasks):
        from vllm.entrypoints.pooling import init_pooling_state

        init_pooling_state(engine_client, state, args, request_logger, supported_tasks)

    state.enable_server_load_tracking = args.enable_server_load_tracking
    state.server_load_metrics = 0

343

344
def create_server_socket(addr: tuple[str, int]) -> socket.socket:
    """Create a TCP stream socket bound to ``addr``.

    The address family is IPv6 when ``addr[0]`` is a valid IPv6 address and
    IPv4 otherwise. ``SO_REUSEADDR`` is enabled, and so is ``SO_REUSEPORT``
    where the platform defines it. The socket is bound; ``listen()`` is not
    called here.

    Args:
        addr: ``(host, port)`` pair; an empty host binds all interfaces.

    Returns:
        The bound socket. The caller owns it and must close it.

    Raises:
        OSError: If the socket cannot be configured or bound. The socket is
            closed before the error propagates.
    """
    family = socket.AF_INET6 if is_valid_ipv6_address(addr[0]) else socket.AF_INET
    sock = socket.socket(family=family, type=socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # SO_REUSEPORT is not defined on every platform (e.g. Windows).
        if hasattr(socket, "SO_REUSEPORT"):
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(addr)
    except BaseException:
        # Don't leak the file descriptor when binding fails.
        sock.close()
        raise
    return sock


357
358
359
360
361
362
def create_server_unix_socket(path: str) -> socket.socket:
    """Create a UNIX-domain stream socket bound to the filesystem ``path``.

    ``listen()`` is not called here.

    Args:
        path: Filesystem path for the socket. It must not already exist.

    Returns:
        The bound socket. The caller owns it and must close it.

    Raises:
        OSError: If binding fails, e.g. because ``path`` already exists. The
            socket is closed before the error propagates.
    """
    sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
    try:
        sock.bind(path)
    except BaseException:
        # Don't leak the file descriptor when binding fails.
        sock.close()
        raise
    return sock


363
def validate_api_server_args(args):
    """Check that the requested tool-call and reasoning parsers are registered.

    Raises:
        KeyError: If ``args.enable_auto_tool_choice`` is set and
            ``args.tool_call_parser`` is not a registered tool parser, or if
            ``args.structured_outputs_config.reasoning_parser`` is set to a
            name that is not a registered reasoning parser.
    """
    valid_tool_parsers = ToolParserManager.list_registered()
    if args.enable_auto_tool_choice and args.tool_call_parser not in valid_tool_parsers:
        raise KeyError(
            f"invalid tool call parser: {args.tool_call_parser} "
            f"(choose from {{ {','.join(valid_tool_parsers)} }})"
        )

    valid_reasoning_parsers = ReasoningParserManager.list_registered()
    reasoning_parser = args.structured_outputs_config.reasoning_parser
    # An unset/empty reasoning parser is allowed and skips the check.
    if reasoning_parser and reasoning_parser not in valid_reasoning_parsers:
        raise KeyError(
            f"invalid reasoning parser: {reasoning_parser} "
            f"(choose from {{ {','.join(valid_reasoning_parsers)} }})"
        )
379

380

381
@instrument(span_name="API server setup")
def setup_server(args):
    """Do the one-time preparation needed before the API server can serve.

    Logs version/model information and non-default args, loads optional tool
    and reasoning parser plugins, and validates the parser arguments. It then
    binds the listening socket (UNIX-domain when ``args.uds`` is set, TCP
    otherwise), calls ``set_ulimit()``, and installs a SIGTERM handler that
    raises ``KeyboardInterrupt`` while start-up is still in progress.

    Returns:
        ``(listen_address, sock)``. ``listen_address`` is a display URL,
        either ``unix:<path>`` or ``http[s]://host:port``; ``sock`` is the
        bound socket.
    """
    log_version_and_model(logger, VLLM_VERSION, args.model)
    log_non_default_args(args)

    tool_plugin = args.tool_parser_plugin
    if tool_plugin and len(tool_plugin) > 3:
        ToolParserManager.import_tool_parser(tool_plugin)

    reasoning_plugin = args.reasoning_parser_plugin
    if reasoning_plugin and len(reasoning_plugin) > 3:
        ReasoningParserManager.import_reasoning_parser(reasoning_plugin)

    validate_api_server_args(args)

    # Bind the port *before* the engine is set up to avoid race conditions
    # with ray; see https://github.com/vllm-project/vllm/issues/8204
    uds_path = args.uds
    if uds_path:
        sock = create_server_unix_socket(uds_path)
    else:
        bind_host = args.host or ""
        sock = create_server_socket((bind_host, args.port))

    # Workaround for uvicorn dropping requests when too many concurrent
    # requests are active.
    set_ulimit()

    def _raise_on_sigterm(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, _raise_on_sigterm)

    if uds_path:
        return f"unix:{uds_path}", sock

    scheme = "https" if args.ssl_keyfile and args.ssl_certfile else "http"
    if is_valid_ipv6_address(bind_host):
        host_part = f"[{bind_host}]"
    else:
        host_part = bind_host or "0.0.0.0"
    return f"{scheme}://{host_part}:{args.port}", sock


async def run_server(args, **uvicorn_kwargs) -> None:
    """Run the API server as a single worker in the current process.

    Prefixes this process's stdout/stderr with ``APIServer``, performs the
    one-time setup through ``setup_server``, then serves on the resulting
    socket until shutdown.
    """
    decorate_logs("APIServer")
    address, server_sock = setup_server(args)
    await run_server_worker(address, server_sock, args, **uvicorn_kwargs)


436
437
438
async def run_server_worker(
    listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
    """Run a single API server worker on an already-bound socket.

    Loads optional parser plugins, resolves the uvicorn log config, builds
    the engine client and the FastAPI app, and serves HTTP on ``sock``.
    Shutdown is awaited only after the engine-client context has exited.
    ``sock`` is always closed when this coroutine finishes, including when
    any start-up step fails.

    Args:
        listen_address: Display address, used only for logging.
        sock: Bound listening socket handed to the HTTP server.
        args: Parsed CLI namespace.
        client_config: Optional per-API-server settings forwarded to
            ``build_async_engine_client``.
        **uvicorn_kwargs: Extra keyword arguments forwarded to ``serve_http``.
    """
    try:
        if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
            ToolParserManager.import_tool_parser(args.tool_parser_plugin)

        if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3:
            ReasoningParserManager.import_reasoning_parser(
                args.reasoning_parser_plugin
            )

        # Get uvicorn log config (from file or with endpoint filter)
        log_config = get_uvicorn_log_config(args)
        if log_config is not None:
            uvicorn_kwargs["log_config"] = log_config

        async with build_async_engine_client(
            args,
            client_config=client_config,
        ) as engine_client:
            supported_tasks = await engine_client.get_supported_tasks()
            logger.info("Supported tasks: %s", supported_tasks)

            app = build_app(args, supported_tasks)
            await init_app_state(engine_client, app.state, args, supported_tasks)

            logger.info(
                "Starting vLLM API server %d on %s",
                engine_client.vllm_config.parallel_config._api_process_rank,
                listen_address,
            )
            shutdown_task = await serve_http(
                app,
                sock=sock,
                enable_ssl_refresh=args.enable_ssl_refresh,
                host=args.host,
                port=args.port,
                log_level=args.uvicorn_log_level,
                # NOTE: When the 'disable_uvicorn_access_log' value is True,
                # no access log will be output.
                access_log=not args.disable_uvicorn_access_log,
                timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs,
                ssl_ciphers=args.ssl_ciphers,
                h11_max_incomplete_event_size=args.h11_max_incomplete_event_size,
                h11_max_header_count=args.h11_max_header_count,
                **uvicorn_kwargs,
            )

        # NB: Await server shutdown only after the backend context is exited
        await shutdown_task
    finally:
        # Release the listening socket on every exit path. Previously a
        # failure before the server started left it open.
        sock.close()
493

Ethan Xu's avatar
Ethan Xu committed
494
495
496

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
    # entrypoints.
    cli_env_setup()
    cli_parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."
        )
    )
    cli_args = cli_parser.parse_args()
    validate_parsed_serve_args(cli_args)
    uvloop.run(run_server(cli_args))