run_batch.py 21.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import tempfile
6
from argparse import Namespace
7
from collections.abc import Awaitable, Callable
8
from http import HTTPStatus
9
from io import StringIO
10
from typing import Any, TypeAlias
11
12

import aiohttp
13
import torch
14
from prometheus_client import start_http_server
15
16
from pydantic import TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
17
from tqdm import tqdm
18

19
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
20
from vllm.engine.protocol import EngineClient
21
from vllm.entrypoints.logger import RequestLogger
22
from vllm.entrypoints.openai.chat_completion.protocol import (
23
    ChatCompletionRequest,
24
    ChatCompletionResponse,
25
26
27
)
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import (
28
    ErrorResponse,
29
    OpenAIBaseModel,
30
)
31
32
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
33
34
35
36
37
38
39
40
41
from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest, EmbeddingResponse
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
    RerankResponse,
    ScoreRequest,
    ScoreResponse,
)
from vllm.entrypoints.pooling.score.serving import ServingScores
42
from vllm.logger import init_logger
43
from vllm.reasoning import ReasoningParserManager
44
45
from vllm.utils import random_uuid
from vllm.utils.argparse_utils import FlexibleArgumentParser
46
from vllm.version import __version__ as VLLM_VERSION
47

48
49
# Module-level logger, created through vLLM's logging helper.
logger = init_logger(__name__)

50

51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Union of the request body types a batch input line may carry. Used both as
# the annotation of `BatchRequestInput.body` and as the fallback validation
# target when the URL does not identify a specific request type.
BatchRequestInputBody: TypeAlias = (
    ChatCompletionRequest | EmbeddingRequest | ScoreRequest | RerankRequest
)


class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: The body is validated against the request type implied by `url`:
    `/v1/chat/completions`, `/v1/embeddings`, and URLs ending in `/score`
    or `/rerank`.
    """

    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. It selects which
    # request model the body is validated against (see check_type_for_url).
    url: str

    # The parameters of the request.
    body: BatchRequestInputBody

    @field_validator("body", mode="plain")
    @classmethod
    def check_type_for_url(cls, value: Any, info: ValidationInfo):
        """Validate `body` against the request model selected by `url`."""
        # `info.data` only holds fields that were validated successfully, so
        # `url` is absent when it was missing or invalid. Use .get() so that
        # case is reported as a regular validation error for `url` instead of
        # escaping as a raw KeyError from this validator.
        url: str | None = info.data.get("url")
        if url is not None:
            # Use url to disambiguate models
            if url == "/v1/chat/completions":
                return ChatCompletionRequest.model_validate(value)
            if url == "/v1/embeddings":
                return TypeAdapter(EmbeddingRequest).validate_python(value)
            if url.endswith("/score"):
                return ScoreRequest.model_validate(value)
            if url.endswith("/rerank"):
                return RerankRequest.model_validate(value)
        # Unknown (or unavailable) URL: fall back to validating against the
        # whole union of supported body types.
        return TypeAdapter(BatchRequestInputBody).validate_python(value)


# Response part of a batch output line: a status code, a request id and,
# optionally, the endpoint's response body.
class BatchResponseData(OpenAIBaseModel):
    # HTTP status code of the response. Defaults to 200.
    status_code: int = 200

    # A unique identifier for the API request.
    request_id: str

    # The body of the response. Defaults to None, i.e. no body attached.
    body: (
        ChatCompletionResponse
        | EmbeddingResponse
        | ScoreResponse
        | RerankResponse
        | None
    ) = None


class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files
    """

    # Identifier of this output record.
    id: str

    # A developer-provided per-request id that will be used to match outputs to
    # inputs.
    custom_id: str

    # Response data for the request; the field is required but may be None.
    response: BatchResponseData | None

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Any | None


129
def make_arg_parser(parser: FlexibleArgumentParser):
    """
    Register the batch-runner command line options on `parser`.

    The options are added in this order: input/output locations, the chat
    response role, the engine options contributed by `AsyncEngineArgs`,
    logging and Prometheus metrics settings, and usage-reporting switches.

    Returns the parser handed back by `AsyncEngineArgs.add_cli_args`, with
    the remaining options added to it.
    """
    # --- Where the batch is read from and written to -----------------------
    parser.add_argument(
        "-i",
        "--input-file",
        type=str,
        required=True,
        help="The path or url to a single input file. Currently supports local file "
        "paths, or the http protocol (http or https). If a URL is specified, "
        "the file should be available via HTTP GET.",
    )
    parser.add_argument(
        "-o",
        "--output-file",
        type=str,
        required=True,
        help="The path or url to a single output file. Currently supports "
        "local file paths, or web (http or https) urls. If a URL is specified,"
        " the file should be available via HTTP PUT.",
    )
    parser.add_argument(
        "--output-tmp-dir",
        default=None,
        type=str,
        help="The directory to store the output file before uploading it "
        "to the output URL.",
    )

    # --- Chat behaviour ----------------------------------------------------
    parser.add_argument(
        "--response-role",
        default="assistant",
        type=optional_type(str),
        help="The role name to return if `request.add_generation_prompt=True`.",
    )

    # --- Engine options (the returned parser replaces the local one) -------
    parser = AsyncEngineArgs.add_cli_args(parser)

    # --- Logging -----------------------------------------------------------
    parser.add_argument(
        "--max-log-len",
        default=None,
        type=int,
        help="Max number of prompt characters or prompt "
        "ID numbers being printed in log."
        "\n\nDefault: Unlimited",
    )

    # --- Prometheus metrics ------------------------------------------------
    parser.add_argument(
        "--enable-metrics",
        action="store_true",
        help="Enable Prometheus metrics",
    )
    parser.add_argument(
        "--url",
        default="0.0.0.0",
        type=str,
        help="URL to the Prometheus metrics server "
        "(only needed if enable-metrics is set).",
    )
    parser.add_argument(
        "--port",
        default=8000,
        type=int,
        help="Port number for the Prometheus metrics server "
        "(only needed if enable-metrics is set).",
    )

    # --- Usage reporting ---------------------------------------------------
    parser.add_argument(
        "--enable-prompt-tokens-details",
        default=False,
        action="store_true",
        help="If set to True, enable prompt_tokens_details in usage.",
    )
    parser.add_argument(
        "--enable-force-include-usage",
        default=False,
        action="store_true",
        help="If set to True, include usage on every request "
        "(even when stream_options is not specified)",
    )

    return parser


def parse_args():
    """Build the batch-runner argument parser and parse the command line."""
    cli_parser = make_arg_parser(
        FlexibleArgumentParser(description="vLLM OpenAI-Compatible batch runner.")
    )
    return cli_parser.parse_args()
210
211


212
213
214
215
216
217
218
219
220
221
# Explicitly use a pure-text progress format that ends with a newline.
# This makes it impossible to see the animation in the progress bar,
# but avoids garbled output under ray or multiprocessing, which wrap
# each line of output with some prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"  # noqa: E501


class BatchProgressTracker:
    """
    Counts submitted batch requests and reports completions on a tqdm bar.

    Call `submitted()` once per request handed to the engine, create the bar
    with `pbar()` when all requests have been submitted, and call
    `completed()` as each request finishes.
    """

    def __init__(self) -> None:
        # Number of requests submitted so far; becomes the bar's total.
        self._total = 0
        # Created by pbar(); None until then.
        self._pbar: tqdm | None = None

    def submitted(self) -> None:
        """Record that one more request has been submitted."""
        self._total += 1

    def completed(self) -> None:
        """Advance the progress bar by one step, if it has been created."""
        # Compare against None explicitly so this does not depend on how the
        # tqdm object itself evaluates in a boolean context.
        if self._pbar is not None:
            self._pbar.update()

    def pbar(self) -> tqdm:
        """
        Create, store and return the progress bar for the submitted requests.

        The bar is disabled unless torch.distributed is uninitialized or the
        current process is rank 0.
        """
        enable_tqdm = (
            not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
        )
        self._pbar = tqdm(
            total=self._total,
            unit="req",
            desc="Running batch",
            mininterval=5,
            disable=not enable_tqdm,
            bar_format=_BAR_FORMAT,
        )
        return self._pbar


246
247
async def read_file(path_or_url: str) -> str:
    """
    Return the text content of a local file or of an http(s) URL.

    path_or_url: A local file path, or a URL starting with http:// or
    https:// that is fetched with HTTP GET.

    Raises aiohttp.ClientResponseError if the server answers with an HTTP
    error status, and OSError if the local file cannot be opened.
    """
    if path_or_url.startswith(("http://", "https://")):
        async with aiohttp.ClientSession() as session, session.get(path_or_url) as resp:
            # Fail on HTTP error statuses instead of handing the error page
            # back to the caller as if it were the batch input.
            resp.raise_for_status()
            return await resp.text()

    with open(path_or_url, encoding="utf-8") as f:
        return f.read()


255
256
257
async def write_local_file(
    output_path: str, batch_outputs: list[BatchRequestOutput]
) -> None:
    """
    Serialize the batch outputs to a local file, one JSON document per line.
    output_path: The path of the file to (over)write.
    batch_outputs: The batch outputs to serialize, in order.
    """
    # This is synchronous file I/O inside a coroutine. As long as run_batch
    # runs as a standalone program, blocking the event loop here won't
    # affect performance.
    with open(output_path, "w", encoding="utf-8") as out:
        out.writelines(f"{item.model_dump_json()}\n" for item in batch_outputs)


270
async def upload_data(
    output_url: str, data_or_file: str | bytes, from_file: bool
) -> None:
    """
    Upload a local file or an in-memory payload to a URL with HTTP PUT.
    output_url: The URL to upload to.
    data_or_file: Either the data to upload or the path to the file to upload.
    from_file: If True, data_or_file is the path to the file to upload.

    Raises Exception (chained to the last underlying error) when every
    attempt has failed.
    """
    # Timeout is a common issue when uploading large files.
    # We retry max_retries times before giving up.
    max_retries = 5
    # Number of seconds to wait before retrying.
    delay = 5

    for attempt in range(1, max_retries + 1):
        try:
            # We increase the timeout to 1000 seconds to allow
            # for large files (default is 300).
            async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=1000)
            ) as session:
                if from_file:
                    # The file is re-opened on every attempt, so a retry
                    # always sends it from the beginning.
                    with open(data_or_file, "rb") as file:
                        await _put_payload(session, output_url, file, "file")
                else:
                    await _put_payload(session, output_url, data_or_file, "data")
            return

        except Exception as e:
            if attempt < max_retries:
                logger.error(
                    "Failed to upload data (attempt %d). Error message: %s.\nRetrying in %d seconds...",  # noqa: E501
                    attempt,
                    e,
                    delay,
                )
                await asyncio.sleep(delay)
            else:
                raise Exception(
                    f"Failed to upload data (attempt {attempt}). Error message: {str(e)}."  # noqa: E501
                ) from e


async def _put_payload(
    session: "aiohttp.ClientSession", output_url: str, payload: Any, kind: str
) -> None:
    """PUT `payload` to `output_url`; raise unless the server answers 2xx.

    kind: "file" or "data"; only used in the wording of the error message.
    """
    async with session.put(output_url, data=payload) as response:
        # A successful PUT may be answered with 200, 201 or 204, so accept
        # the whole 2xx range rather than 200 only.
        if 200 <= response.status < 300:
            return
        # response.text() is a coroutine: await it so the message contains
        # the response body rather than the repr of a coroutine object.
        response_text = await response.text()
        raise Exception(
            f"Failed to upload {kind}.\n"
            f"Status: {response.status}\n"
            f"Response: {response_text}"
        )
321
322


323
324
325
async def write_file(
    path_or_url: str,
    batch_outputs: list[BatchRequestOutput],
    output_tmp_dir: str | None,
) -> None:
    """
    Write batch_outputs to a file or upload to a URL.
    path_or_url: The path or URL to write batch_outputs to.
    batch_outputs: The list of batch outputs to write.
    output_tmp_dir: The directory to store the output file before uploading it
    to the output URL. If None, the outputs are serialized in memory and
    uploaded directly. Only used when path_or_url is an http(s) URL.
    """
    if not path_or_url.startswith(("http://", "https://")):
        logger.info("Writing outputs to local file %s", path_or_url)
        await write_local_file(path_or_url, batch_outputs)
        return

    if output_tmp_dir is None:
        logger.info("Writing outputs to memory buffer")
        # One JSON document per line, without a trailing newline.
        payload = "\n".join(o.model_dump_json() for o in batch_outputs)
        logger.info("Uploading outputs to %s", path_or_url)
        await upload_data(
            path_or_url,
            payload.encode("utf-8"),
            from_file=False,
        )
        return

    # Write responses to a temporary file and then upload it to the URL.
    # The temporary file only exists for the duration of the `with` block.
    with tempfile.NamedTemporaryFile(
        mode="w",
        encoding="utf-8",
        dir=output_tmp_dir,
        prefix="tmp_batch_output_",
        suffix=".jsonl",
    ) as f:
        logger.info("Writing outputs to temporary local file %s", f.name)
        await write_local_file(f.name, batch_outputs)
        logger.info("Uploading outputs to %s", path_or_url)
        await upload_data(path_or_url, f.name, from_file=True)
362
363


364
365
366
def make_error_request_output(
    request: BatchRequestInput, error_msg: str
) -> BatchRequestOutput:
    """
    Build the output record for a request that is rejected with an error.

    The record carries the request's custom_id, a response with status
    400 (Bad Request) and no body, and `error_msg` as the error.
    """
    output_id = f"vllm-{random_uuid()}"
    rejected = BatchResponseData(
        status_code=HTTPStatus.BAD_REQUEST,
        request_id=f"vllm-batch-{random_uuid()}",
    )
    return BatchRequestOutput(
        id=output_id,
        custom_id=request.custom_id,
        response=rejected,
        error=error_msg,
    )


async def make_async_error_request_output(
    request: BatchRequestInput, error_msg: str
) -> BatchRequestOutput:
    """
    Awaitable variant of `make_error_request_output`.

    It lets an immediately-known error be queued alongside the coroutines of
    real requests and awaited in the same way.
    """
    error_output = make_error_request_output(request, error_msg)
    return error_output


385
386
387
388
389
async def run_request(
    serving_engine_func: Callable,
    request: BatchRequestInput,
    tracker: BatchProgressTracker,
) -> BatchRequestOutput:
    """
    Run one batch request through a serving handler and wrap its result.

    serving_engine_func: Awaited with the request body.
    request: The batch input line being processed.
    tracker: Notified once the result has been wrapped.

    A recognised response becomes the output's response body, an
    ErrorResponse becomes the output's error (with its code as the status
    code), and anything else is reported as a stream-mode error.
    """
    result = await serving_engine_func(request.body)

    success_types = (
        ChatCompletionResponse,
        EmbeddingResponse,
        ScoreResponse,
        RerankResponse,
    )

    if isinstance(result, success_types):
        output = BatchRequestOutput(
            id=f"vllm-{random_uuid()}",
            custom_id=request.custom_id,
            response=BatchResponseData(
                body=result,
                request_id=f"vllm-batch-{random_uuid()}",
            ),
            error=None,
        )
    elif isinstance(result, ErrorResponse):
        output = BatchRequestOutput(
            id=f"vllm-{random_uuid()}",
            custom_id=request.custom_id,
            response=BatchResponseData(
                status_code=result.error.code,
                request_id=f"vllm-batch-{random_uuid()}",
            ),
            error=result,
        )
    else:
        # Neither a full response nor an error response.
        output = make_error_request_output(
            request, error_msg="Request must not be sent in stream mode"
        )

    tracker.completed()
    return output


423
def validate_run_batch_args(args):
    """
    Check that the configured reasoning parser, if any, is a registered one.

    Raises KeyError when `args.structured_outputs_config.reasoning_parser` is
    set to a name that `ReasoningParserManager` does not list.
    """
    registered = ReasoningParserManager.list_registered()
    requested = args.structured_outputs_config.reasoning_parser

    # Nothing configured, or a known parser: the arguments are fine.
    if not requested or requested in registered:
        return

    raise KeyError(
        f"invalid reasoning parser: {requested} "
        f"(chose from {{ {','.join(registered)} }})"
    )


434
435
436
437
async def run_batch(
    engine_client: EngineClient,
    args: Namespace,
) -> None:
    """
    Run every request of the batch input file and write the outputs.

    engine_client: The engine used by all serving objects.
    args: Parsed command line arguments (see `make_arg_parser`).

    Each non-empty line of `args.input_file` is parsed as a
    `BatchRequestInput` and dispatched by URL to the chat, embedding, score
    or rerank handler. Requests whose endpoint is unknown, or not supported
    by the loaded model, produce an error output instead. All requests are
    awaited together and the outputs are written to `args.output_file`.
    """
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.enable_log_requests:
        request_logger = RequestLogger(max_log_len=args.max_log_len)
    else:
        request_logger = None

    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model) for name in served_model_names
    ]

    model_config = engine_client.model_config
    supported_tasks = await engine_client.get_supported_tasks()
    logger.info("Supported tasks: %s", supported_tasks)

    # Create the openai serving objects. Each one is None when the model
    # does not support the corresponding task.
    openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        base_model_paths=base_model_paths,
        lora_modules=None,
    )

    openai_serving_chat = (
        OpenAIServingChat(
            engine_client,
            openai_serving_models,
            args.response_role,
            request_logger=request_logger,
            chat_template=None,
            chat_template_content_format="auto",
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            default_chat_template_kwargs=getattr(
                args, "default_chat_template_kwargs", None
            ),
        )
        if "generate" in supported_tasks
        else None
    )

    openai_serving_embedding = (
        OpenAIServingEmbedding(
            engine_client,
            openai_serving_models,
            request_logger=request_logger,
            chat_template=None,
            chat_template_content_format="auto",
        )
        if "embed" in supported_tasks
        else None
    )

    enable_serving_reranking = (
        "classify" in supported_tasks
        and getattr(model_config.hf_config, "num_labels", 0) == 1
    )

    openai_serving_scores = (
        ServingScores(
            engine_client,
            openai_serving_models,
            request_logger=request_logger,
            score_template=None,
        )
        if ("embed" in supported_tasks or enable_serving_reranking)
        else None
    )

    # The handlers do not depend on the individual request, so resolve them
    # once here instead of on every input line.
    chat_handler_fn = (
        openai_serving_chat.create_chat_completion
        if openai_serving_chat is not None
        else None
    )
    embed_handler_fn = (
        openai_serving_embedding.create_embedding
        if openai_serving_embedding is not None
        else None
    )
    score_handler_fn = (
        openai_serving_scores.create_score
        if openai_serving_scores is not None
        else None
    )
    rerank_handler_fn = (
        openai_serving_scores.do_rerank if openai_serving_scores is not None else None
    )

    def select_handler(url: str) -> tuple[Callable | None, str] | None:
        """Map a request URL to (handler, API name), or None if unknown.

        The handler is None when the endpoint is known but the loaded model
        does not support it.
        """
        if url == "/v1/chat/completions":
            return chat_handler_fn, "Chat Completions"
        if url == "/v1/embeddings":
            return embed_handler_fn, "Embeddings"
        if url.endswith("/score"):
            return score_handler_fn, "Scores"
        if url.endswith("/rerank"):
            return rerank_handler_fn, "Rerank"
        return None

    tracker = BatchProgressTracker()
    logger.info("Reading batch from %s...", args.input_file)

    # Submit all requests in the file to the engine "concurrently".
    response_futures: list[Awaitable[BatchRequestOutput]] = []
    for request_json in (await read_file(args.input_file)).strip().split("\n"):
        # Skip empty lines.
        request_json = request_json.strip()
        if not request_json:
            continue

        request = BatchRequestInput.model_validate_json(request_json)

        # Determine the type of request and run it.
        route = select_handler(request.url)
        if route is None:
            response_futures.append(
                make_async_error_request_output(
                    request,
                    error_msg=f"URL {request.url} was used. "
                    "Supported endpoints: /v1/chat/completions, /v1/embeddings,"
                    " /score, /rerank. "
                    "See vllm/entrypoints/openai/api_server.py for supported "
                    "score/rerank versions.",
                )
            )
            continue

        handler_fn, api_name = route
        if handler_fn is None:
            response_futures.append(
                make_async_error_request_output(
                    request,
                    error_msg=f"The model does not support {api_name} API",
                )
            )
            continue

        # Only requests actually sent to the engine count towards the
        # progress bar total; error outputs above are not tracked.
        response_futures.append(run_request(handler_fn, request, tracker))
        tracker.submitted()

    with tracker.pbar():
        responses = await asyncio.gather(*response_futures)

    await write_file(args.output_file, responses, args.output_tmp_dir)
608
609


610
async def main(args: Namespace):
    """
    Validate the arguments, build an engine client and run the batch.

    args: Parsed command line arguments (see `make_arg_parser`).
    """
    # These two imports are done inside the function, not at module level.
    from vllm.entrypoints.openai.api_server import build_async_engine_client
    from vllm.usage.usage_lib import UsageContext

    validate_run_batch_args(args)

    client_context = build_async_engine_client(
        args,
        usage_context=UsageContext.OPENAI_BATCH_RUNNER,
        disable_frontend_multiprocessing=False,
    )
    async with client_context as client:
        await run_batch(client, args)
622
623


624
625
626
if __name__ == "__main__":
    args = parse_args()

    logger.info("vLLM batch processing API version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    # Optionally expose Prometheus metrics over HTTP before the batch starts,
    # on the address and port given on the command line. LLMEngine uses the
    # Prometheus client to publish metrics at the /metrics endpoint.
    if not args.enable_metrics:
        logger.info("Prometheus metrics disabled")
    else:
        logger.info("Prometheus metrics enabled")
        start_http_server(port=args.port, addr=args.url)

    asyncio.run(main(args))