"vllm/benchmark_throughput.py" did not exist on "4338cc475029dcd37a291a867d52419122648e72"
api_server.py 26.2 KB
Newer Older
1
import asyncio
2
import atexit
3
4
import importlib
import inspect
5
import multiprocessing
6
import os
7
import re
8
import signal
9
import socket
10
import tempfile
11
import uuid
12
from argparse import Namespace
13
from contextlib import asynccontextmanager
14
from functools import partial
15
from http import HTTPStatus
16
from typing import AsyncIterator, Optional, Set, Tuple
17

18
import uvloop
19
from fastapi import APIRouter, FastAPI, Request
Zhuohan Li's avatar
Zhuohan Li committed
20
21
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
22
from fastapi.responses import JSONResponse, Response, StreamingResponse
23
from starlette.datastructures import State
24
from starlette.routing import Mount
25
from typing_extensions import assert_never
Zhuohan Li's avatar
Zhuohan Li committed
26

27
import vllm.envs as envs
28
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
29
from vllm.engine.arg_utils import AsyncEngineArgs
30
from vllm.engine.async_llm_engine import AsyncLLMEngine  # type: ignore
31
32
33
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
34
from vllm.entrypoints.chat_utils import load_chat_template
35
from vllm.entrypoints.launcher import serve_http
36
from vllm.entrypoints.logger import RequestLogger
37
38
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                              validate_parsed_serve_args)
39
40
# yapf conflicts with isort for this block
# yapf: disable
41
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
42
                                              ChatCompletionResponse,
43
                                              CompletionRequest,
44
                                              CompletionResponse,
45
46
                                              DetokenizeRequest,
                                              DetokenizeResponse,
47
48
                                              EmbeddingRequest,
                                              EmbeddingResponse, ErrorResponse,
49
                                              LoadLoraAdapterRequest,
50
                                              ScoreRequest, ScoreResponse,
51
                                              TokenizeRequest,
52
53
54
                                              TokenizeResponse,
                                              UnloadLoraAdapterRequest)
# yapf: enable
55
56
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
57
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
58
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
59
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
60
61
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
62
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
63
from vllm.entrypoints.utils import with_cancellation
64
from vllm.logger import init_logger
yhu422's avatar
yhu422 committed
65
from vllm.usage.usage_lib import UsageContext
66
67
from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
                        is_valid_ipv6_address)
68
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
69

70
# Keep-alive timeout that run_server() forwards to serve_http().
TIMEOUT_KEEP_ALIVE = 5  # seconds

# Declared here, assigned (via ``global``) only when this process creates its
# own prometheus multiprocess directory. Holding the TemporaryDirectory in a
# module global keeps it from being cleaned up until interpreter exit.
prometheus_multiproc_dir: tempfile.TemporaryDirectory

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.api_server')

# Strong references to background tasks started by lifespan(); each task
# removes itself from this set via a done-callback when it finishes.
_running_tasks: Set[asyncio.Task] = set()
78

79

80
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    When ``app.state.log_stats`` is set, starts a background task that asks
    the engine client to log stats every 10 seconds, and cancels it on
    shutdown. In all cases ``app.state`` is deleted on exit so the engine
    reference it holds can be garbage-collected.
    """
    log_task: Optional[asyncio.Task] = None
    try:
        if app.state.log_stats:
            client: EngineClient = app.state.engine_client

            async def _force_log():
                while True:
                    await asyncio.sleep(10.)
                    await client.do_log_stats()

            log_task = asyncio.create_task(_force_log())
            # Keep a strong reference until the task completes.
            _running_tasks.add(log_task)
            log_task.add_done_callback(_running_tasks.remove)

        try:
            yield
        finally:
            if log_task is not None:
                log_task.cancel()
    finally:
        # Ensure app state including engine ref is gc'd
        del app.state
104
105


106
@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[EngineClient]:
    """Build an EngineClient from parsed CLI ``args`` for an ``async with``.

    Context manager to handle engine_client lifecycle: it delegates to
    build_async_engine_client_from_engine_args, which is responsible for
    shutting everything down and cleaning up on error/exit.
    """
    use_inprocess_frontend = args.disable_frontend_multiprocessing
    engine_args = AsyncEngineArgs.from_cli_args(args)

    async with build_async_engine_client_from_engine_args(
            engine_args, use_inprocess_frontend) as client:
        yield client


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[EngineClient]:
    """
    Create EngineClient, either:
        - in-process using the AsyncLLMEngine Directly
        - multiprocess using AsyncLLMEngine RPC

    Yields the client for the duration of the ``async with`` block. On exit,
    including when the block raises, the client is shut down. In the
    multiprocess case the engine process is also terminated (then killed
    if it has not exited within 4 seconds).

    Raises:
        RuntimeError: the engine process died or reported failure before
            the client finished its setup handshake.
    """

    # Fall back
    # TODO: fill out feature matrix.
    if (MQLLMEngineClient.is_unsupported_config(engine_args)
            or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):

        engine_config = engine_args.create_engine_config(
            UsageContext.OPENAI_API_SERVER)
        uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
                           "uses_ray", False)

        build_engine = partial(AsyncLLMEngine.from_engine_args,
                               engine_args=engine_args,
                               engine_config=engine_config,
                               usage_context=UsageContext.OPENAI_API_SERVER)
        if uses_ray:
            # Must run in main thread with ray for its signal handlers to work
            engine_client = build_engine()
        else:
            engine_client = await asyncio.get_running_loop().run_in_executor(
                None, build_engine)

        try:
            yield engine_client
        finally:
            # Shut down even when the body of the ``async with`` raised;
            # without the finally, an exception skipped this cleanup.
            if hasattr(engine_client, "shutdown"):
                engine_client.shutdown()
        return

    # Otherwise, use the multiprocessing AsyncLLMEngine.
    if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
        # Make TemporaryDirectory for prometheus multiprocessing
        # Note: global TemporaryDirectory will be automatically
        #   cleaned up upon exit.
        global prometheus_multiproc_dir
        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
        os.environ[
            "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
    else:
        logger.warning(
            "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
            "This directory must be wiped between vLLM runs or "
            "you will find inaccurate metrics. Unset the variable "
            "and vLLM will properly handle cleanup.")

    # Select random path for IPC.
    ipc_path = get_open_zmq_ipc_path()
    logger.debug("Multiprocessing frontend to use %s for IPC Path.",
                 ipc_path)

    # Start RPCServer in separate process (holds the LLMEngine).
    # the current process might have CUDA context,
    # so we need to spawn a new process
    context = multiprocessing.get_context("spawn")

    # The Process can raise an exception during startup, which may
    # not actually result in an exitcode being reported. As a result
    # we use a shared variable to communicate the information.
    engine_alive = multiprocessing.Value('b', True, lock=False)
    engine_process = context.Process(target=run_mp_engine,
                                     args=(engine_args,
                                           UsageContext.OPENAI_API_SERVER,
                                           ipc_path, engine_alive))
    engine_process.start()
    engine_pid = engine_process.pid
    assert engine_pid is not None, "Engine process failed to start."
    logger.info("Started engine process with PID %d", engine_pid)

    def _cleanup_ipc_path():
        socket_path = ipc_path.replace("ipc://", "")
        if os.path.exists(socket_path):
            os.remove(socket_path)

    # Ensure we clean up the local IPC socket file on exit.
    atexit.register(_cleanup_ipc_path)

    # Built inside the try so that a failure while creating the config or
    # the client still terminates the already-started engine process.
    mq_engine_client: Optional[MQLLMEngineClient] = None
    try:
        # Build RPCClient, which conforms to EngineClient Protocol.
        engine_config = engine_args.create_engine_config()
        build_client = partial(MQLLMEngineClient, ipc_path, engine_config,
                               engine_pid)
        mq_engine_client = await asyncio.get_running_loop().run_in_executor(
            None, build_client)

        # Retry the handshake until it succeeds or the engine is seen dead.
        while True:
            try:
                await mq_engine_client.setup()
                break
            except TimeoutError:
                if (not engine_process.is_alive()
                        or not engine_alive.value):
                    raise RuntimeError(
                        "Engine process failed to start. See stack "
                        "trace for the root cause.") from None

        yield mq_engine_client  # type: ignore[misc]
    finally:
        # Ensure rpc server process was terminated
        engine_process.terminate()

        # Close all open connections to the backend (if we got that far)
        if mq_engine_client is not None:
            mq_engine_client.close()

        # Wait for engine process to join
        engine_process.join(4)
        if engine_process.exitcode is None:
            # Kill if it has not exited within the 4-second join timeout.
            engine_process.kill()

        # Lazy import for prometheus multiprocessing.
        # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
        # before prometheus_client is imported.
        # See https://prometheus.github.io/client_python/multiprocess/
        from prometheus_client import multiprocess
        multiprocess.mark_process_dead(engine_process.pid)
243

244

Ethan Xu's avatar
Ethan Xu committed
245
# Shared router: the endpoint handlers below register on it with
# ``@router.get`` / ``@router.post``, and build_app() includes it in the app.
router: APIRouter = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
246

247

248
def mount_metrics(app: FastAPI):
    """Mount a prometheus ASGI app on ``/metrics`` of ``app``.

    If PROMETHEUS_MULTIPROC_DIR is set, metrics are gathered through a
    multiprocess collector on a fresh registry; otherwise prometheus_client's
    default registry is served.
    """
    # Lazy import for prometheus multiprocessing.
    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    multiproc_dir = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if multiproc_dir is None:
        metrics_asgi_app = make_asgi_app()
    else:
        logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
                     multiproc_dir)
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)
        metrics_asgi_app = make_asgi_app(registry=registry)

    # Add prometheus asgi middleware to route /metrics requests
    route = Mount("/metrics", metrics_asgi_app)
    # Workaround for 307 Redirect for /metrics
    route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(route)
272
273


274
275
276
277
278
279
def base(request: Request) -> OpenAIServing:
    """Return a generic OpenAIServing handler for shared helpers.

    Reuses the tokenization handler, which init_app_state() always creates,
    instead of constructing a new instance.
    """
    serving = tokenization(request)
    return serving


def chat(request: Request) -> Optional[OpenAIServingChat]:
    """Return the chat handler, or None if the model does not provide one."""
    app_state = request.app.state
    return app_state.openai_serving_chat


283
def completion(request: Request) -> Optional[OpenAIServingCompletion]:
    """Return the completion handler, or None if the model lacks one."""
    app_state = request.app.state
    return app_state.openai_serving_completion


287
288
def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
    """Return the embedding handler, or None if the model lacks one."""
    app_state = request.app.state
    return app_state.openai_serving_embedding
289
290


291
292
293
294
def score(request: Request) -> Optional[OpenAIServingScores]:
    """Return the score handler, or None if the model lacks one."""
    app_state = request.app.state
    return app_state.openai_serving_scores


295
296
def tokenization(request: Request) -> OpenAIServingTokenization:
    """Return the tokenization handler stored on the app state."""
    app_state = request.app.state
    return app_state.openai_serving_tokenization
297
298


299
def engine_client(request: Request) -> EngineClient:
    """Return the EngineClient stored on the app state."""
    app_state = request.app.state
    return app_state.engine_client


Ethan Xu's avatar
Ethan Xu committed
303
@router.get("/health")
async def health(raw_request: Request) -> Response:
    """Health check."""
    # NOTE(review): presumably check_health() raises when the engine is
    # unhealthy, so reaching the return means healthy — confirm.
    client = engine_client(raw_request)
    await client.check_health()
    return Response(status_code=200)


Ethan Xu's avatar
Ethan Xu committed
310
@router.post("/tokenize")
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
    # Delegate to the tokenization handler. It yields either an error
    # (returned with its own status code) or a TokenizeResponse.
    result = await tokenization(raw_request).create_tokenize(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, TokenizeResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)

324

Ethan Xu's avatar
Ethan Xu committed
325
@router.post("/detokenize")
@with_cancellation
async def detokenize(request: DetokenizeRequest, raw_request: Request):
    # Delegate to the tokenization handler. It yields either an error
    # (returned with its own status code) or a DetokenizeResponse.
    result = await tokenization(raw_request).create_detokenize(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, DetokenizeResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)

339

Ethan Xu's avatar
Ethan Xu committed
340
@router.get("/v1/models")
async def show_available_models(raw_request: Request):
    # Any OpenAIServing instance can list the served models.
    model_list = await base(raw_request).show_available_models()
    payload = model_list.model_dump()
    return JSONResponse(content=payload)
Zhuohan Li's avatar
Zhuohan Li committed
346
347


Ethan Xu's avatar
Ethan Xu committed
348
@router.get("/version")
async def show_version():
    # Report the running vLLM package version.
    return JSONResponse(content=dict(version=VLLM_VERSION))


Ethan Xu's avatar
Ethan Xu committed
354
@router.post("/v1/chat/completions")
@with_cancellation
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    # Serve an OpenAI chat-completions request. Errors are returned as JSON
    # with their own status code; a full ChatCompletionResponse as JSON;
    # anything else is streamed as text/event-stream.
    handler = chat(raw_request)
    if handler is None:
        # Returning the ErrorResponse model directly would let FastAPI
        # serialize it with the route's default HTTP 200. Wrap it so the HTTP
        # status matches the error code, like the error branch below.
        err = base(raw_request).create_error_response(
            message="The model does not support Chat Completions API")
        return JSONResponse(content=err.model_dump(), status_code=err.code)

    generator = await handler.create_chat_completion(request, raw_request)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)

    elif isinstance(generator, ChatCompletionResponse):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator, media_type="text/event-stream")

374

Ethan Xu's avatar
Ethan Xu committed
375
@router.post("/v1/completions")
@with_cancellation
async def create_completion(request: CompletionRequest, raw_request: Request):
    # Serve an OpenAI completions request. Errors are returned as JSON with
    # their own status code; a full CompletionResponse as JSON; anything
    # else is streamed as text/event-stream.
    handler = completion(raw_request)
    if handler is None:
        # Returning the ErrorResponse model directly would let FastAPI
        # serialize it with the route's default HTTP 200. Wrap it so the HTTP
        # status matches the error code, like the error branch below.
        err = base(raw_request).create_error_response(
            message="The model does not support Completions API")
        return JSONResponse(content=err.model_dump(), status_code=err.code)

    generator = await handler.create_completion(request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    elif isinstance(generator, CompletionResponse):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator, media_type="text/event-stream")

Zhuohan Li's avatar
Zhuohan Li committed
392

Ethan Xu's avatar
Ethan Xu committed
393
@router.post("/v1/embeddings")
@with_cancellation
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    # Serve an OpenAI embeddings request. The handler yields either an
    # error (returned with its own status code) or an EmbeddingResponse.
    handler = embedding(raw_request)
    if handler is None:
        # Returning the ErrorResponse model directly would let FastAPI
        # serialize it with the route's default HTTP 200. Wrap it so the HTTP
        # status matches the error code, like the error branch below.
        err = base(raw_request).create_error_response(
            message="The model does not support Embeddings API")
        return JSONResponse(content=err.model_dump(), status_code=err.code)

    generator = await handler.create_embedding(request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    elif isinstance(generator, EmbeddingResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)

410

411
@router.post("/score")
@with_cancellation
async def create_score(request: ScoreRequest, raw_request: Request):
    # Serve a score request. The handler yields either an error (returned
    # with its own status code) or a ScoreResponse.
    handler = score(raw_request)
    if handler is None:
        # Returning the ErrorResponse model directly would let FastAPI
        # serialize it with the route's default HTTP 200. Wrap it so the HTTP
        # status matches the error code, like the error branch below.
        err = base(raw_request).create_error_response(
            message="The model does not support Score API")
        return JSONResponse(content=err.model_dump(), status_code=err.code)

    generator = await handler.create_score(request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    elif isinstance(generator, ScoreResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)


429
@router.post("/v1/score")
@with_cancellation
async def create_score_v1(request: ScoreRequest, raw_request: Request):
    # Deprecated alias kept for older clients: warn, then forward the call
    # unchanged to the /score handler.
    logger.warning("To indicate that Score API is not part of standard "
                   "OpenAI API, we have moved it to `/score`. "
                   "Please update your client accordingly.")

    forwarded = await create_score(request, raw_request)
    return forwarded


439
440
441
442
443
444
if envs.VLLM_TORCH_PROFILER_DIR:
    # Profiling endpoints are registered only when a torch profiler output
    # directory is configured.
    logger.warning("Torch Profiler is enabled in the API server. "
                   "This should ONLY be used for local development!")

    @router.post("/start_profile")
    async def start_profile(raw_request: Request):
        client = engine_client(raw_request)
        logger.info("Starting profiler...")
        await client.start_profile()
        logger.info("Profiler started.")
        return Response(status_code=200)

    @router.post("/stop_profile")
    async def stop_profile(raw_request: Request):
        client = engine_client(raw_request)
        logger.info("Stopping profiler...")
        await client.stop_profile()
        logger.info("Profiler stopped.")
        return Response(status_code=200)


459
460
461
462
463
464
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
    # Runtime LoRA endpoints are registered only when explicitly enabled.
    logger.warning(
        "Lora dynamic loading & unloading is enabled in the API server. "
        "This should ONLY be used for local development!")

    @router.post("/v1/load_lora_adapter")
    async def load_lora_adapter(request: LoadLoraAdapterRequest,
                                raw_request: Request):
        # Load the adapter into every available chat/completion/embedding
        # handler, stopping at the first error. On success, the last
        # handler's response is returned as the HTTP body.
        handled = False
        for route in [chat, completion, embedding]:
            handler = route(raw_request)
            if handler is not None:
                handled = True
                response = await handler.load_lora_adapter(request)
                if isinstance(response, ErrorResponse):
                    return JSONResponse(content=response.model_dump(),
                                        status_code=response.code)

        if not handled:
            # No handler ran, so `response` was never bound; previously this
            # surfaced as an UnboundLocalError (HTTP 500).
            err = base(raw_request).create_error_response(
                message="The model does not support LoRA adapters")
            return JSONResponse(content=err.model_dump(),
                                status_code=err.code)

        return Response(status_code=200, content=response)

    @router.post("/v1/unload_lora_adapter")
    async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
                                  raw_request: Request):
        # Unload the adapter from every available chat/completion/embedding
        # handler, stopping at the first error. On success, the last
        # handler's response is returned as the HTTP body.
        handled = False
        for route in [chat, completion, embedding]:
            handler = route(raw_request)
            if handler is not None:
                handled = True
                response = await handler.unload_lora_adapter(request)
                if isinstance(response, ErrorResponse):
                    return JSONResponse(content=response.model_dump(),
                                        status_code=response.code)

        if not handled:
            # No handler ran, so `response` was never bound; previously this
            # surfaced as an UnboundLocalError (HTTP 500).
            err = base(raw_request).create_error_response(
                message="The model does not support LoRA adapters")
            return JSONResponse(content=err.model_dump(),
                                status_code=err.code)

        return Response(status_code=200, content=response)


491
def build_app(args: Namespace) -> FastAPI:
    """Create the FastAPI application for the OpenAI-compatible server.

    The app gets:
    - the shared router and the ``/metrics`` mount;
    - CORS middleware and a 400 handler for request validation errors;
    - optional bearer-token authentication on ``/v1`` paths when an API key
      is configured;
    - an ``X-Request-Id`` echo middleware;
    - any middleware listed in ``args.middleware``.

    Raises:
        ValueError: an entry in ``args.middleware`` is neither a class nor a
            coroutine function.
    """
    if args.disable_fastapi_docs:
        app = FastAPI(openapi_url=None,
                      docs_url=None,
                      redoc_url=None,
                      lifespan=lifespan)
    else:
        app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = ErrorResponse(message=str(exc),
                            type="BadRequestError",
                            code=HTTPStatus.BAD_REQUEST)
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.VLLM_API_KEY or args.api_key:
        import secrets

        # The expected header value is encoded once. "surrogatepass" makes
        # the encoding a bijection on str, so byte equality below is exactly
        # the str equality the check used before.
        expected_auth = ("Bearer " + token).encode("utf-8", "surrogatepass")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            # CORS preflight requests carry no credentials.
            if request.method == "OPTIONS":
                return await call_next(request)
            url_path = request.url.path
            if app.root_path and url_path.startswith(app.root_path):
                url_path = url_path[len(app.root_path):]
            # Only the /v1 API surface is protected.
            if not url_path.startswith("/v1"):
                return await call_next(request)
            provided_auth = request.headers.get("Authorization", "").encode(
                "utf-8", "surrogatepass")
            # Constant-time comparison: the header is attacker-controlled,
            # and `!=` would leak how much of the key matched via timing.
            if not secrets.compare_digest(provided_auth, expected_auth):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    @app.middleware("http")
    async def add_request_id(request: Request, call_next):
        # Echo the caller's X-Request-Id, or generate one.
        request_id = request.headers.get("X-Request-Id") or uuid.uuid4().hex
        response = await call_next(request)
        response.headers["X-Request-Id"] = request_id
        return response

    for middleware in args.middleware:
        # Each entry is a dotted path "package.module.Object".
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app


557
def init_app_state(
    engine_client: EngineClient,
    model_config: ModelConfig,
    state: State,
    args: Namespace,
) -> None:
    """Populate ``state`` with the engine client and the serving handlers.

    Chat and completion handlers exist only for "generate" runners. The
    embedding handler exists only for "pooling" runners, and the score
    handler only for pooling cross-encoders. Handlers that don't apply are
    set to None. The tokenization handler is always created.
    """
    served_model_names = (args.served_model_name
                          if args.served_model_name is not None else
                          [args.model])

    request_logger = (None if args.disable_log_requests else
                      RequestLogger(max_log_len=args.max_log_len))

    base_model_paths = [
        BaseModelPath(name=served_name, model_path=args.model)
        for served_name in served_model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats

    resolved_chat_template = load_chat_template(args.chat_template)
    logger.info("Using supplied chat template:\n%s", resolved_chat_template)

    if model_config.runner_type == "generate":
        state.openai_serving_chat = OpenAIServingChat(
            engine_client,
            model_config,
            base_model_paths,
            args.response_role,
            lora_modules=args.lora_modules,
            prompt_adapters=args.prompt_adapters,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            tool_parser=args.tool_call_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
        )
    else:
        state.openai_serving_chat = None

    if model_config.runner_type == "generate":
        state.openai_serving_completion = OpenAIServingCompletion(
            engine_client,
            model_config,
            base_model_paths,
            lora_modules=args.lora_modules,
            prompt_adapters=args.prompt_adapters,
            request_logger=request_logger,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        )
    else:
        state.openai_serving_completion = None

    if model_config.runner_type == "pooling":
        state.openai_serving_embedding = OpenAIServingEmbedding(
            engine_client,
            model_config,
            base_model_paths,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
        )
    else:
        state.openai_serving_embedding = None

    # `and` short-circuits, so is_cross_encoder is read only for pooling.
    if (model_config.runner_type == "pooling"
            and model_config.is_cross_encoder):
        state.openai_serving_scores = OpenAIServingScores(
            engine_client,
            model_config,
            base_model_paths,
            request_logger=request_logger,
        )
    else:
        state.openai_serving_scores = None

    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        model_config,
        base_model_paths,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
    )
632
633


634
635
636
637
638
639
640
641
642
643
644
645
def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
    """Create a TCP socket bound to ``addr`` with SO_REUSEADDR set.

    The socket is only bound, not put into listening mode. The family is
    AF_INET6 when ``addr[0]`` is a valid IPv6 address, else AF_INET.

    Raises:
        OSError: the socket could not be configured or bound (for example,
            the port is in use). The socket is closed before the error
            propagates, so no file descriptor leaks.
    """
    if is_valid_ipv6_address(addr[0]):
        family = socket.AF_INET6
    else:
        family = socket.AF_INET

    sock = socket.socket(family=family, type=socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(addr)
    except BaseException:
        # Do not leak the descriptor when configuration or bind fails.
        sock.close()
        raise

    return sock


646
async def run_server(args, **uvicorn_kwargs) -> None:
    """Start the engine and serve the OpenAI-compatible HTTP API.

    Steps:
    1. Validate the tool-parser configuration.
    2. Bind the listening port early.
    3. Build the engine client and the app, then start serve_http().
    4. After the engine context exits, wait for the server to shut down.

    The port-reservation socket is closed on every exit path.

    Raises:
        KeyError: auto tool choice is enabled with an unknown
            tool-call parser.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
        ToolParserManager.import_tool_parser(args.tool_parser_plugin)

    valid_tool_parsers = ToolParserManager.tool_parsers.keys()
    if args.enable_auto_tool_choice \
        and args.tool_call_parser not in valid_tool_parsers:
        raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
                       f"(chose from {{ {','.join(valid_tool_parsers)} }})")

    # workaround to make sure that we bind the port before the engine is set up.
    # This avoids race conditions with ray.
    # see https://github.com/vllm-project/vllm/issues/8204
    sock_addr = (args.host or "", args.port)
    sock = create_server_socket(sock_addr)

    try:

        def signal_handler(*_) -> None:
            # Interrupt server on sigterm while initializing
            raise KeyboardInterrupt("terminated")

        signal.signal(signal.SIGTERM, signal_handler)

        async with build_async_engine_client(args) as engine_client:
            app = build_app(args)

            model_config = await engine_client.get_model_config()
            init_app_state(engine_client, model_config, app.state, args)

            shutdown_task = await serve_http(
                app,
                host=args.host,
                port=args.port,
                log_level=args.uvicorn_log_level,
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs,
                **uvicorn_kwargs,
            )

        # NB: Await server shutdown only after the backend context is exited
        await shutdown_task
    finally:
        # Release the reserved port even if startup or serving failed.
        sock.close()

Ethan Xu's avatar
Ethan Xu committed
695
696
697
698
699
700
701
702

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."))
    args = parser.parse_args()
    validate_parsed_serve_args(args)

    # Run the async server on uvloop's event loop.
    server_coro = run_server(args)
    uvloop.run(server_coro)