run_batch.py 31.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import base64
6
import tempfile
7
from argparse import Namespace
8
from collections.abc import Awaitable, Callable
9
from http import HTTPStatus
10
from io import BytesIO, StringIO
11
from typing import Any, TypeAlias
12
from urllib.parse import urlparse
13
14

import aiohttp
15
import torch
16
from fastapi import UploadFile
17
from prometheus_client import start_http_server
18
from pydantic import Field, TypeAdapter, field_validator, model_validator
19
from pydantic_core.core_schema import ValidationInfo
20
from tqdm import tqdm
21

22
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
23
from vllm.engine.protocol import EngineClient
24
from vllm.entrypoints.logger import RequestLogger
25
from vllm.entrypoints.openai.chat_completion.protocol import (
26
    ChatCompletionRequest,
27
    ChatCompletionResponse,
28
29
30
)
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import (
31
    ErrorInfo,
32
    ErrorResponse,
33
    OpenAIBaseModel,
34
)
35
36
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from vllm.entrypoints.openai.speech_to_text.protocol import (
    TranscriptionRequest,
    TranscriptionResponse,
    TranscriptionResponseVerbose,
    TranslationRequest,
    TranslationResponse,
    TranslationResponseVerbose,
)
from vllm.entrypoints.openai.speech_to_text.serving import (
    OpenAIServingTranscription,
    OpenAIServingTranslation,
)
from vllm.entrypoints.pooling.embed.protocol import (
    EmbeddingRequest,
    EmbeddingResponse,
)
53
54
55
56
57
58
59
60
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
    RerankResponse,
    ScoreRequest,
    ScoreResponse,
)
from vllm.entrypoints.pooling.score.serving import ServingScores
61
from vllm.logger import init_logger
62
from vllm.reasoning import ReasoningParserManager
63
from vllm.tasks import SupportedTask
64
65
from vllm.utils import random_uuid
from vllm.utils.argparse_utils import FlexibleArgumentParser
66
from vllm.version import __version__ as VLLM_VERSION
67

68
69
logger = init_logger(__name__)

70

71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
class BatchTranscriptionRequest(TranscriptionRequest):
    """
    Transcription request variant used by the batch runner.

    Batch input lines are JSON, so audio cannot be attached as an upload.
    This subclass adds a required ``file_url`` field and turns the inherited
    ``file`` field into an optional, excluded placeholder that input data
    must not supply.
    """

    file_url: str = Field(
        ...,
        description=(
            "Either a URL of the audio or a data URL with base64 encoded audio data. "
        ),
    )

    # Inherited upload field: optional here and excluded from model dumps.
    file: UploadFile | None = Field(default=None, exclude=True)  # type: ignore[assignment]

    @model_validator(mode="before")
    @classmethod
    def validate_no_file(cls, data: Any):
        """Reject dict input that carries a ``file`` key; pass the rest through."""
        carries_file = isinstance(data, dict) and "file" in data
        if not carries_file:
            return data
        raise ValueError(
            "The 'file' field is not supported in batch requests. "
            "Use 'file_url' instead."
        )


class BatchTranslationRequest(TranslationRequest):
    """
    Translation request variant used by the batch runner.

    Batch input lines are JSON, so audio cannot be attached as an upload.
    This subclass adds a required ``file_url`` field and turns the inherited
    ``file`` field into an optional, excluded placeholder that input data
    must not supply.
    """

    file_url: str = Field(
        ...,
        description=(
            "Either a URL of the audio or a data URL with base64 encoded audio data. "
        ),
    )

    # Inherited upload field: optional here and excluded from model dumps.
    file: UploadFile | None = Field(default=None, exclude=True)  # type: ignore[assignment]

    @model_validator(mode="before")
    @classmethod
    def validate_no_file(cls, data: Any):
        """Reject dict input that carries a ``file`` key; pass the rest through."""
        carries_file = isinstance(data, dict) and "file" in data
        if not carries_file:
            return data
        raise ValueError(
            "The 'file' field is not supported in batch requests. "
            "Use 'file_url' instead."
        )


131
# Union of every request model that may appear as the `body` of a batch input
# line. BatchRequestInput picks the concrete member based on the line's `url`.
BatchRequestInputBody: TypeAlias = (
    ChatCompletionRequest
    | EmbeddingRequest
    | ScoreRequest
    | RerankRequest
    | BatchTranscriptionRequest
    | BatchTranslationRequest
)


class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    The `body` is validated against the request model that matches `url`:
    chat completions, embeddings, score, rerank, audio transcriptions and
    audio translations.
    """

    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request, e.g.
    # /v1/chat/completions, /v1/embeddings, /v1/audio/transcriptions,
    # /v1/audio/translations, or a path ending in /score or /rerank.
    url: str

    # The parameters of the request.
    body: BatchRequestInputBody

    @field_validator("body", mode="plain")
    @classmethod
    def check_type_for_url(cls, value: Any, info: ValidationInfo):
        """Validate `body` with the request model selected by `url`.

        `info.data` only holds fields that already validated successfully, so
        `url` is absent when it was missing or invalid. In that case (and for
        URLs not listed below) fall back to validating against the full
        union instead of failing with a KeyError.
        """
        url = info.data.get("url")
        if isinstance(url, str):
            if url == "/v1/chat/completions":
                return ChatCompletionRequest.model_validate(value)
            if url == "/v1/embeddings":
                return TypeAdapter(EmbeddingRequest).validate_python(value)
            if url.endswith("/score"):
                return TypeAdapter(ScoreRequest).validate_python(value)
            if url.endswith("/rerank"):
                return RerankRequest.model_validate(value)
            if url == "/v1/audio/transcriptions":
                return BatchTranscriptionRequest.model_validate(value)
            if url == "/v1/audio/translations":
                return BatchTranslationRequest.model_validate(value)
        return TypeAdapter(BatchRequestInputBody).validate_python(value)


class BatchResponseData(OpenAIBaseModel):
    """The `response` part of one batch output line."""

    # HTTP status code of the response.
    status_code: int = 200

    # A unique identifier for the API request.
    request_id: str

    # The body of the response; left as None when no response body is set
    # (for example on error outputs).
    body: (
        ChatCompletionResponse
        | EmbeddingResponse
        | ScoreResponse
        | RerankResponse
        | TranscriptionResponse
        | TranscriptionResponseVerbose
        | TranslationResponse
        | TranslationResponseVerbose
        | None
    ) = None


class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files.

    One instance is produced for every request read from the batch input.
    """

    # Identifier assigned to this output line by the batch runner.
    id: str

    # A developer-provided per-request id that will be used to match outputs to
    # inputs.
    custom_id: str

    # Status code, request id and (optional) body produced for the request.
    response: BatchResponseData | None

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Any | None


222
def make_arg_parser(parser: FlexibleArgumentParser):
    """Register the batch-runner command line options and return the parser.

    Batch I/O options are added first, then the engine options via
    ``AsyncEngineArgs.add_cli_args``, then logging, metrics and usage options.
    """
    # --- batch input / output -------------------------------------------
    parser.add_argument(
        "-i",
        "--input-file",
        type=str,
        required=True,
        help=(
            "The path or url to a single input file. "
            "Currently supports local file paths, or the http protocol "
            "(http or https). If a URL is specified, "
            "the file should be available via HTTP GET."
        ),
    )
    parser.add_argument(
        "-o",
        "--output-file",
        type=str,
        required=True,
        help=(
            "The path or url to a single output file. "
            "Currently supports local file paths, or web (http or https) urls. "
            "If a URL is specified, the file should be available via HTTP PUT."
        ),
    )
    parser.add_argument(
        "--output-tmp-dir",
        default=None,
        type=str,
        help=(
            "The directory to store the output file "
            "before uploading it to the output URL."
        ),
    )
    parser.add_argument(
        "--response-role",
        default="assistant",
        type=optional_type(str),
        help="The role name to return if `request.add_generation_prompt=True`.",
    )

    # --- engine options ---------------------------------------------------
    parser = AsyncEngineArgs.add_cli_args(parser)

    # --- logging ------------------------------------------------------------
    parser.add_argument(
        "--max-log-len",
        default=None,
        type=int,
        help=(
            "Max number of prompt characters or prompt ID numbers "
            "being printed in log.\n\nDefault: Unlimited"
        ),
    )

    # --- Prometheus metrics -------------------------------------------------
    metrics_only_note = "(only needed if enable-metrics is set)."
    parser.add_argument(
        "--enable-metrics",
        action="store_true",
        help="Enable Prometheus metrics",
    )
    parser.add_argument(
        "--url",
        default="0.0.0.0",
        type=str,
        help="URL to the Prometheus metrics server " + metrics_only_note,
    )
    parser.add_argument(
        "--port",
        default=8000,
        type=int,
        help="Port number for the Prometheus metrics server " + metrics_only_note,
    )

    # --- usage reporting ----------------------------------------------------
    parser.add_argument(
        "--enable-prompt-tokens-details",
        default=False,
        action="store_true",
        help="If set to True, enable prompt_tokens_details in usage.",
    )
    parser.add_argument(
        "--enable-force-include-usage",
        default=False,
        action="store_true",
        help=(
            "If set to True, include usage on every request "
            "(even when stream_options is not specified)"
        ),
    )

    return parser


def parse_args():
    """Build the batch-runner argument parser and parse the command line."""
    batch_parser = make_arg_parser(
        FlexibleArgumentParser(description="vLLM OpenAI-Compatible batch runner.")
    )
    return batch_parser.parse_args()
303
304


305
306
307
308
309
310
311
312
313
314
# Progress-bar format used by BatchProgressTracker.
# Explicitly use a pure text format that ends with a newline. This makes it
# impossible to see the animation in the progress bar, but avoids garbled
# output under ray or multiprocessing, which wrap each line of output with
# some prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"  # noqa: E501


class BatchProgressTracker:
    """Counts submitted requests and drives a tqdm bar as they complete.

    Call `submitted()` once per scheduled request, then `pbar()` to create
    the bar sized to the submitted count, then `completed()` as each request
    finishes.
    """

    def __init__(self):
        # Number of requests registered through submitted().
        self._total = 0
        # Created lazily by pbar(); None until then.
        self._pbar: tqdm | None = None

    def submitted(self):
        """Record that one more request has been scheduled."""
        self._total += 1

    def completed(self):
        """Advance the progress bar by one, if it has been created."""
        # Compare against None explicitly: a tqdm object's truthiness depends
        # on its `total`, not on whether the bar exists.
        if self._pbar is not None:
            self._pbar.update()

    def pbar(self) -> tqdm:
        """Create, store and return the progress bar.

        The bar is disabled unless torch.distributed is uninitialized or the
        current process is rank 0.
        """
        enable_tqdm = (
            not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
        )
        self._pbar = tqdm(
            total=self._total,
            unit="req",
            desc="Running batch",
            mininterval=5,
            disable=not enable_tqdm,
            bar_format=_BAR_FORMAT,
        )
        return self._pbar


339
340
async def read_file(path_or_url: str) -> str:
    """Return the text content of a local file or an http(s) URL.

    Args:
        path_or_url: A local path, or a URL starting with http:// or https://
            that is fetched with HTTP GET.

    Returns:
        The content as a string. Local files are decoded as UTF-8.

    Raises:
        aiohttp.ClientResponseError: If the URL responds with a status of
            400 or higher.
        OSError: If the local file cannot be opened.
    """
    if path_or_url.startswith(("http://", "https://")):
        async with aiohttp.ClientSession() as session, session.get(path_or_url) as resp:
            # Fail loudly on HTTP errors instead of handing an error page to
            # the caller as if it were batch input.
            resp.raise_for_status()
            return await resp.text()

    with open(path_or_url, encoding="utf-8") as f:
        return f.read()


348
349
350
async def write_local_file(
    output_path: str, batch_outputs: list[BatchRequestOutput]
) -> None:
    """
    Serialize each batch output as one JSON line in a local file.

    output_path: Destination path; the file is created or truncated.
    batch_outputs: The outputs to write, in order.
    """
    # The file I/O here is synchronous. That is acceptable while run_batch is
    # a standalone program: blocking the event loop does not hurt anything.
    with open(output_path, "w", encoding="utf-8") as out:
        out.writelines(f"{item.model_dump_json()}\n" for item in batch_outputs)


363
async def _put_and_check(session, output_url: str, payload, what: str) -> None:
    """PUT `payload` to `output_url`; raise unless the response status is 200."""
    async with session.put(output_url, data=payload) as response:
        if response.status != 200:
            # ClientResponse.text() is a coroutine; await it so the error
            # message contains the body rather than a coroutine repr.
            body = await response.text()
            raise Exception(
                f"Failed to upload {what}.\n"
                f"Status: {response.status}\n"
                f"Response: {body}"
            )


async def upload_data(
    output_url: str, data_or_file: str | bytes, from_file: bool
) -> None:
    """
    Upload data or a local file to a URL with HTTP PUT, retrying on failure.

    output_url: The URL to upload to.
    data_or_file: Either the data to upload or the path to the file to upload.
    from_file: If True, data_or_file is the path to the file to upload.

    Returns as soon as one attempt succeeds. Raises Exception (chained to the
    last error) once all attempts have failed.
    """
    # Timeout is a common issue when uploading large files.
    # We retry max_retries times before giving up.
    max_retries = 5
    # Number of seconds to wait before retrying.
    delay = 5

    for attempt in range(1, max_retries + 1):
        try:
            # We increase the timeout to 1000 seconds to allow
            # for large files (default is 300).
            async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=1000)
            ) as session:
                if from_file:
                    # Re-open the file on every attempt so each retry sends
                    # it from the beginning.
                    with open(data_or_file, "rb") as file:
                        await _put_and_check(session, output_url, file, "file")
                else:
                    await _put_and_check(session, output_url, data_or_file, "data")
        except Exception as e:
            if attempt < max_retries:
                logger.error(
                    "Failed to upload data (attempt %d). Error message: %s.\nRetrying in %d seconds...",  # noqa: E501
                    attempt,
                    e,
                    delay,
                )
                await asyncio.sleep(delay)
            else:
                raise Exception(
                    f"Failed to upload data (attempt {attempt}). Error message: {str(e)}."  # noqa: E501
                ) from e
        else:
            # Success: stop here, otherwise the payload would be uploaded
            # again on every remaining loop iteration.
            return
414
415


416
417
418
async def write_file(
    path_or_url: str,
    batch_outputs: list[BatchRequestOutput],
    output_tmp_dir: str | None,
) -> None:
    """
    Write batch_outputs to a file or upload to a URL.

    path_or_url: The path or URL to write batch_outputs to.
    batch_outputs: The list of batch outputs to write.
    output_tmp_dir: The directory to store the output file before uploading it
    to the output URL. If None, the outputs are serialized in memory and
    uploaded without a temporary file.
    """
    if not path_or_url.startswith(("http://", "https://")):
        logger.info("Writing outputs to local file %s", path_or_url)
        await write_local_file(path_or_url, batch_outputs)
        return

    if output_tmp_dir is None:
        logger.info("Writing outputs to memory buffer")
        output_buffer = StringIO()
        for o in batch_outputs:
            print(o.model_dump_json(), file=output_buffer)
        logger.info("Uploading outputs to %s", path_or_url)
        await upload_data(
            path_or_url,
            output_buffer.getvalue().strip().encode("utf-8"),
            from_file=False,
        )
        return

    # Write responses to a temporary file and then upload it to the URL.
    # The temporary file is removed when the `with` block exits.
    with tempfile.NamedTemporaryFile(
        mode="w",
        encoding="utf-8",
        dir=output_tmp_dir,
        prefix="tmp_batch_output_",
        suffix=".jsonl",
    ) as f:
        logger.info("Writing outputs to temporary local file %s", f.name)
        await write_local_file(f.name, batch_outputs)
        logger.info("Uploading outputs to %s", path_or_url)
        await upload_data(path_or_url, f.name, from_file=True)
455
456


457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
async def download_bytes_from_url(url: str) -> bytes:
    """
    Fetch raw bytes from an HTTP(S) URL, or decode them from a data URL.

    Args:
        url: Either an HTTP/HTTPS URL or a data URL (data:...;base64,...)

    Returns:
        Data as bytes

    Raises:
        ValueError: For a data URL without a comma or without a base64
            marker in its header, and for any other URL scheme.
        Exception: If the HTTP response status is not 200.
    """
    scheme = urlparse(url).scheme

    if scheme == "data":
        # Format: data:...;base64,<base64_data>
        header, comma, payload = url.partition(",")
        if not comma:
            raise ValueError(f"Invalid data URL format: {url}")
        if "base64" not in header:
            raise ValueError(f"Unsupported data URL encoding: {header}")
        return base64.b64decode(payload)

    if scheme not in ("http", "https"):
        raise ValueError(
            f"Unsupported URL scheme: {scheme}. "
            "Supported schemes: http, https, data"
        )

    async with aiohttp.ClientSession() as session, session.get(url) as resp:
        if resp.status != 200:
            raise Exception(
                f"Failed to download data from URL: {url}. Status: {resp.status}"
            )
        return await resp.read()


500
501
502
def make_error_request_output(
    request: BatchRequestInput, error_msg: str
) -> BatchRequestOutput:
    """Build a 400 (Bad Request) output line for `request` carrying `error_msg`.

    Fresh random ids are generated for the output and for its response.
    """
    error_response = BatchResponseData(
        status_code=HTTPStatus.BAD_REQUEST,
        request_id=f"vllm-batch-{random_uuid()}",
    )
    return BatchRequestOutput(
        id=f"vllm-{random_uuid()}",
        custom_id=request.custom_id,
        response=error_response,
        error=error_msg,
    )


async def make_async_error_request_output(
    request: BatchRequestInput, error_msg: str
) -> BatchRequestOutput:
    """Awaitable form of `make_error_request_output`.

    Lets error outputs be collected alongside real request coroutines.
    """
    error_output = make_error_request_output(request, error_msg)
    return error_output


521
522
523
524
525
async def run_request(
    serving_engine_func: Callable,
    request: BatchRequestInput,
    tracker: BatchProgressTracker,
) -> BatchRequestOutput:
    """Run one batch request through its handler and wrap the result.

    Args:
        serving_engine_func: Awaitable handler called with `request.body`.
        request: The batch input line being processed.
        tracker: Progress tracker; notified once the output has been built.

    Returns:
        A BatchRequestOutput holding the response body on success, the
        ErrorResponse on handler-reported failure, or a 400 error output for
        any other return value.
    """
    # Response classes that count as a successful result.
    success_types = (
        ChatCompletionResponse,
        EmbeddingResponse,
        ScoreResponse,
        RerankResponse,
        TranscriptionResponse,
        TranscriptionResponseVerbose,
        TranslationResponse,
        TranslationResponseVerbose,
    )

    result = await serving_engine_func(request.body)

    if isinstance(result, success_types):
        response_data = BatchResponseData(
            body=result, request_id=f"vllm-batch-{random_uuid()}"
        )
        output = BatchRequestOutput(
            id=f"vllm-{random_uuid()}",
            custom_id=request.custom_id,
            response=response_data,
            error=None,
        )
    elif isinstance(result, ErrorResponse):
        response_data = BatchResponseData(
            status_code=result.error.code,
            request_id=f"vllm-batch-{random_uuid()}",
        )
        output = BatchRequestOutput(
            id=f"vllm-{random_uuid()}",
            custom_id=request.custom_id,
            response=response_data,
            error=result,
        )
    else:
        # Anything else is reported as a streaming misuse.
        output = make_error_request_output(
            request, error_msg="Request must not be sent in stream mode"
        )

    tracker.completed()
    return output


568
569
570
# A handler decorator: takes an endpoint handler callable and returns the
# callable that is invoked with the batch request body in its place.
WrapperFn: TypeAlias = Callable[[Callable], Callable]


571
572
573
574
575
def handle_endpoint_request(
    request: BatchRequestInput,
    tracker: BatchProgressTracker,
    url_matcher: Callable[[str], bool],
    handler_getter: Callable[[], Callable | None],
    wrapper_fn: WrapperFn | None = None,
) -> Awaitable[BatchRequestOutput] | None:
    """
    Dispatch one batch request to an endpoint if its URL matches.

    Args:
        request: The batch request input
        tracker: Progress tracker for the batch
        url_matcher: Predicate on the request URL; False means "not mine"
        handler_getter: Returns the endpoint handler, or None when the model
            does not provide this endpoint
        wrapper_fn: Optional adapter applied to the handler before it runs
            (e.g., for transcriptions)

    Returns:
        None when the URL does not match. Otherwise an awaitable that yields
        the BatchRequestOutput: an error output when no handler is available,
        or the result of running the handler.
    """
    if url_matcher(request.url):
        endpoint_handler = handler_getter()

        if endpoint_handler is None:
            return make_async_error_request_output(
                request,
                error_msg=f"Model does not support endpoint: {request.url}",
            )

        # Adapt the handler where needed (e.g., transcriptions/translations).
        if wrapper_fn is not None:
            endpoint_handler = wrapper_fn(endpoint_handler)

        # Only requests that will really run count towards the progress total.
        tracker.submitted()
        return run_request(endpoint_handler, request, tracker)

    return None


608
def make_transcription_wrapper(is_translation: bool) -> WrapperFn:
    """
    Build the adapter that lets audio handlers consume batch requests.

    The adapter downloads the audio named by ``file_url``, rebuilds the batch
    request as a regular TranscriptionRequest or TranslationRequest with the
    audio attached as ``file``, and calls the wrapped handler with the raw
    audio bytes and that request. Any exception raised along the way is
    converted into a 400 ErrorResponse.

    Args:
        is_translation: If True, process as translation; otherwise process
            as transcription

    Returns:
        A function that takes a handler and returns a wrapped handler
    """
    # Both depend only on the flag, so resolve them once up front.
    request_cls = TranslationRequest if is_translation else TranscriptionRequest
    operation = "translation" if is_translation else "transcription"

    def wrapper(handler_fn: Callable):
        async def transcription_wrapper(
            batch_request_body: (BatchTranscriptionRequest | BatchTranslationRequest),
        ) -> (
            TranscriptionResponse
            | TranscriptionResponseVerbose
            | TranslationResponse
            | TranslationResponseVerbose
            | ErrorResponse
        ):
            try:
                audio_bytes = await download_bytes_from_url(
                    batch_request_body.file_url
                )

                # Stand-in for the upload a regular API request would carry.
                upload = UploadFile(file=BytesIO(audio_bytes), filename="audio.bin")

                # Copy every field except file_url, then attach the audio.
                fields = batch_request_body.model_dump(exclude={"file_url"})
                fields["file"] = upload

                regular_request = request_cls.model_validate(fields)
                return await handler_fn(audio_bytes, regular_request)
            except Exception as e:
                error_info = ErrorInfo(
                    message=f"Failed to process {operation}: {str(e)}",
                    type="BadRequestError",
                    code=HTTPStatus.BAD_REQUEST.value,
                )
                return ErrorResponse(error=error_info)

        return transcription_wrapper

    return wrapper


def build_endpoint_registry(
    engine_client: EngineClient,
    args: Namespace,
    base_model_paths: list[BaseModelPath],
    request_logger: RequestLogger | None,
    supported_tasks: tuple[SupportedTask, ...],
) -> dict[str, dict[str, Any]]:
    """
    Create the serving objects and describe how each endpoint is dispatched.

    A serving object is only created when the matching task appears in
    ``supported_tasks``; otherwise the endpoint's handler getter returns None.

    Args:
        engine_client: The engine client
        args: Command line arguments
        base_model_paths: List of base model paths
        request_logger: Optional request logger
        supported_tasks: Tuple of supported tasks

    Returns:
        Dictionary keyed by the last URL segment of each endpoint. Every value
        holds ``url_matcher``, ``handler_getter`` and ``wrapper_fn``.
    """

    def lazy_handler(serving: Any, method_name: str) -> Callable[[], Callable | None]:
        # Look the bound method up only when asked; None if not served.
        def getter():
            return None if serving is None else getattr(serving, method_name)

        return getter

    def exact(path: str) -> Callable[[str], bool]:
        return lambda url: url == path

    def suffix(tail: str) -> Callable[[str], bool]:
        return lambda url: url.endswith(tail)

    model_config = engine_client.model_config

    # Create the openai serving objects.
    openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        base_model_paths=base_model_paths,
        lora_modules=None,
    )

    openai_serving_chat = None
    if "generate" in supported_tasks:
        openai_serving_chat = OpenAIServingChat(
            engine_client,
            openai_serving_models,
            args.response_role,
            request_logger=request_logger,
            chat_template=None,
            chat_template_content_format="auto",
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            default_chat_template_kwargs=getattr(
                args, "default_chat_template_kwargs", None
            ),
        )

    openai_serving_embedding = None
    if "embed" in supported_tasks:
        openai_serving_embedding = OpenAIServingEmbedding(
            engine_client,
            openai_serving_models,
            request_logger=request_logger,
            chat_template=None,
            chat_template_content_format="auto",
        )

    # Reranking is enabled for classification models with exactly one label.
    enable_serving_reranking = (
        "classify" in supported_tasks
        and getattr(model_config.hf_config, "num_labels", 0) == 1
    )

    openai_serving_scores = None
    if "embed" in supported_tasks or enable_serving_reranking:
        openai_serving_scores = ServingScores(
            engine_client,
            openai_serving_models,
            request_logger=request_logger,
            score_template=None,
        )

    openai_serving_transcription = None
    if "transcription" in supported_tasks:
        openai_serving_transcription = OpenAIServingTranscription(
            engine_client,
            openai_serving_models,
            request_logger=request_logger,
            enable_force_include_usage=args.enable_force_include_usage,
        )

    # Translation is gated on the "transcription" task as well.
    openai_serving_translation = None
    if "transcription" in supported_tasks:
        openai_serving_translation = OpenAIServingTranslation(
            engine_client,
            openai_serving_models,
            request_logger=request_logger,
            enable_force_include_usage=args.enable_force_include_usage,
        )

    # Registry of endpoint configurations
    endpoint_registry: dict[str, dict[str, Any]] = {
        "completions": {
            "url_matcher": exact("/v1/chat/completions"),
            "handler_getter": lazy_handler(
                openai_serving_chat, "create_chat_completion"
            ),
            "wrapper_fn": None,
        },
        "embeddings": {
            "url_matcher": exact("/v1/embeddings"),
            "handler_getter": lazy_handler(
                openai_serving_embedding, "create_embedding"
            ),
            "wrapper_fn": None,
        },
        "score": {
            "url_matcher": suffix("/score"),
            "handler_getter": lazy_handler(openai_serving_scores, "create_score"),
            "wrapper_fn": None,
        },
        "rerank": {
            "url_matcher": suffix("/rerank"),
            "handler_getter": lazy_handler(openai_serving_scores, "do_rerank"),
            "wrapper_fn": None,
        },
        "transcriptions": {
            "url_matcher": exact("/v1/audio/transcriptions"),
            "handler_getter": lazy_handler(
                openai_serving_transcription, "create_transcription"
            ),
            "wrapper_fn": make_transcription_wrapper(is_translation=False),
        },
        "translations": {
            "url_matcher": exact("/v1/audio/translations"),
            "handler_getter": lazy_handler(
                openai_serving_translation, "create_translation"
            ),
            "wrapper_fn": make_transcription_wrapper(is_translation=True),
        },
    }

    return endpoint_registry


def validate_run_batch_args(args):
    """Check that the configured reasoning parser, if any, is registered.

    Raises:
        KeyError: If a reasoning parser is set but is not among the parsers
            listed by ReasoningParserManager.
    """
    valid_reasoning_parsers = ReasoningParserManager.list_registered()
    reasoning_parser = args.structured_outputs_config.reasoning_parser

    # Nothing configured, or a known parser: accept.
    if not reasoning_parser or reasoning_parser in valid_reasoning_parsers:
        return

    raise KeyError(
        f"invalid reasoning parser: {reasoning_parser} "
        f"(chose from {{ {','.join(valid_reasoning_parsers)} }})"
    )


async def run_batch(
    engine_client: EngineClient,
    args: Namespace,
) -> None:
    """Read the batch input, run every request, and write the outputs.

    Each non-empty input line is parsed into a BatchRequestInput and routed
    through the endpoint registry by the last segment of its URL. Requests
    for unknown endpoints yield an error output instead of stopping the run.
    All requests are awaited together and the outputs are written to
    ``args.output_file``.
    """
    served_model_names = (
        [args.model] if args.served_model_name is None else args.served_model_name
    )

    request_logger = (
        RequestLogger(max_log_len=args.max_log_len)
        if args.enable_log_requests
        else None
    )

    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model) for name in served_model_names
    ]

    supported_tasks = await engine_client.get_supported_tasks()
    logger.info("Supported tasks: %s", supported_tasks)

    endpoint_registry = build_endpoint_registry(
        engine_client=engine_client,
        args=args,
        base_model_paths=base_model_paths,
        request_logger=request_logger,
        supported_tasks=supported_tasks,
    )

    tracker = BatchProgressTracker()
    logger.info("Reading batch from %s...", args.input_file)

    batch_text = await read_file(args.input_file)

    # Submit all requests in the file to the engine "concurrently".
    response_futures: list[Awaitable[BatchRequestOutput]] = []
    for raw_line in batch_text.strip().split("\n"):
        line = raw_line.strip()
        if not line:
            # Skip empty lines.
            continue

        request = BatchRequestInput.model_validate_json(line)

        # Use the last segment of the URL as the endpoint key.
        # More advanced URL matching is done in url_matcher of endpoint_registry.
        endpoint_config = endpoint_registry.get(request.url.split("/")[-1])

        pending = None
        if endpoint_config is not None:
            pending = handle_endpoint_request(
                request,
                tracker,
                url_matcher=endpoint_config["url_matcher"],
                handler_getter=endpoint_config["handler_getter"],
                wrapper_fn=endpoint_config["wrapper_fn"],
            )

        if pending is None:
            # Unknown key, or the key matched but the full URL did not.
            pending = make_async_error_request_output(
                request,
                error_msg=f"URL {request.url} was used. "
                "Supported endpoints: /v1/chat/completions, /v1/embeddings,"
                " /v1/audio/transcriptions, /v1/audio/translations, /score, "
                " /rerank. See vllm/entrypoints/openai/api_server.py "
                "for supported score/rerank versions.",
            )

        response_futures.append(pending)

    with tracker.pbar():
        responses = await asyncio.gather(*response_futures)

    await write_file(args.output_file, responses, args.output_tmp_dir)
919
920


921
async def main(args: Namespace):
    """Validate the arguments, start an engine client and run the batch."""
    from vllm.entrypoints.openai.api_server import build_async_engine_client
    from vllm.usage.usage_lib import UsageContext

    validate_run_batch_args(args)

    client_context = build_async_engine_client(
        args,
        usage_context=UsageContext.OPENAI_BATCH_RUNNER,
        disable_frontend_multiprocessing=False,
    )
    async with client_context as engine_client:
        await run_batch(engine_client, args)
933
934


935
936
937
if __name__ == "__main__":
    args = parse_args()

    logger.info("vLLM batch processing API version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    # LLMEngine uses the Prometheus client to publish metrics at the /metrics
    # endpoint, so the metrics server is started here when requested.
    if not args.enable_metrics:
        logger.info("Prometheus metrics disabled")
    else:
        logger.info("Prometheus metrics enabled")
        start_http_server(port=args.port, addr=args.url)

    asyncio.run(main(args))