run_batch.py 12.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import asyncio
4
from http import HTTPStatus
5
from io import StringIO
6
from typing import Awaitable, Callable, List, Optional
7
8

import aiohttp
9
import torch
10
from prometheus_client import start_http_server
11
from tqdm import tqdm
12
13
14

from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.async_llm_engine import AsyncLLMEngine
15
from vllm.entrypoints.logger import RequestLogger, logger
16
# yapf: disable
17
18
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
                                              BatchRequestOutput,
19
20
                                              BatchResponseData,
                                              ChatCompletionResponse,
21
22
                                              EmbeddingResponse, ErrorResponse,
                                              ScoreResponse)
23
# yapf: enable
24
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
25
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
26
27
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)
28
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
29
from vllm.usage.usage_lib import UsageContext
30
from vllm.utils import FlexibleArgumentParser, random_uuid
31
from vllm.version import __version__ as VLLM_VERSION
32
33
34


def parse_args():
    """Build the batch runner's CLI parser and parse the command line.

    The parser combines the batch-specific options declared here with
    whatever ``AsyncEngineArgs.add_cli_args`` registers on it.

    Returns:
        The namespace produced by ``parser.parse_args()``.
    """
    parser = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible batch runner.")

    # --- Batch input / output locations -----------------------------------
    input_file_help = (
        "The path or url to a single input file. Currently supports local file "
        "paths, or the http protocol (http or https). If a URL is specified, "
        "the file should be available via HTTP GET.")
    parser.add_argument("-i",
                        "--input-file",
                        required=True,
                        type=str,
                        help=input_file_help)

    output_file_help = (
        "The path or url to a single output file. Currently supports "
        "local file paths, or web (http or https) urls. If a URL is specified,"
        " the file should be available via HTTP PUT.")
    parser.add_argument("-o",
                        "--output-file",
                        required=True,
                        type=str,
                        help=output_file_help)

    response_role_help = ("The role name to return if "
                          "`request.add_generation_prompt=True`.")
    parser.add_argument("--response-role",
                        type=nullable_str,
                        default="assistant",
                        help=response_role_help)

    # --- Engine options ----------------------------------------------------
    # Registered here, between the options above and below, so the relative
    # order of arguments on the parser stays as it was.
    parser = AsyncEngineArgs.add_cli_args(parser)

    # --- Request logging ---------------------------------------------------
    max_log_len_help = ("Max number of prompt characters or prompt "
                        "ID numbers being printed in log."
                        "\n\nDefault: Unlimited")
    parser.add_argument("--max-log-len",
                        type=int,
                        default=None,
                        help=max_log_len_help)

    # --- Prometheus metrics ------------------------------------------------
    parser.add_argument("--enable-metrics",
                        action="store_true",
                        help="Enable Prometheus metrics")

    metrics_url_help = ("URL to the Prometheus metrics server "
                        "(only needed if enable-metrics is set).")
    parser.add_argument("--url",
                        type=str,
                        default="0.0.0.0",
                        help=metrics_url_help)

    metrics_port_help = ("Port number for the Prometheus metrics server "
                         "(only needed if enable-metrics is set).")
    parser.add_argument("--port",
                        type=int,
                        default=8000,
                        help=metrics_port_help)

    # --- Usage reporting ---------------------------------------------------
    parser.add_argument(
        "--enable-prompt-tokens-details",
        action="store_true",
        default=False,
        help="If set to True, enable prompt_tokens_details in usage.")

    parsed_args = parser.parse_args()
    return parsed_args


95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# Plain-text progress format that ends every refresh with a newline.
# This gives up the in-place animation of the progress bar, but keeps the
# output readable under ray or multiprocessing, which wrap each line of
# output with some prefix.
_BAR_FORMAT = ("{desc}: {percentage:3.0f}% Completed | "
               "{n_fmt}/{total_fmt} "
               "[{elapsed}<{remaining}, {rate_fmt}]\n")


class BatchProgressTracker:
    """Count submitted requests and advance a tqdm bar as they complete.

    Call ``submitted()`` once per request, then ``pbar()`` to create the bar
    sized to the number of submissions so far. ``completed()`` advances the
    bar and is a no-op while no (truthy) bar has been created.
    """

    def __init__(self):
        # Number of requests registered through submitted().
        self._total = 0
        # Created lazily by pbar(); None until then.
        self._pbar: Optional[tqdm] = None

    def submitted(self):
        """Register one more request in the expected total."""
        self._total = self._total + 1

    def completed(self):
        """Advance the progress bar by one step, if a bar is active."""
        # Truthiness check (not an identity check) is kept on purpose so the
        # behavior matches whatever bool(tqdm) yields for the created bar.
        if not self._pbar:
            return
        self._pbar.update()

    def pbar(self) -> tqdm:
        """Create, remember and return the progress bar.

        The bar is disabled unless torch.distributed is uninitialized or
        this process has distributed rank 0.
        """
        if torch.distributed.is_initialized():
            show_bar = torch.distributed.get_rank() == 0
        else:
            show_bar = True

        progress_bar = tqdm(
            total=self._total,
            unit="req",
            desc="Running batch",
            mininterval=5,
            disable=not show_bar,
            bar_format=_BAR_FORMAT,
        )
        self._pbar = progress_bar
        return progress_bar


127
128
129
130
131
132
async def read_file(path_or_url: str) -> str:
    """Return the full text of a local file or of an HTTP(S) resource.

    Args:
        path_or_url: A local filesystem path, or a URL starting with
            ``http://`` or ``https://`` which is fetched with an HTTP GET.

    Returns:
        The content as a string. Local files are decoded as UTF-8; remote
        bodies are decoded by ``resp.text()``.

    Raises:
        aiohttp.ClientResponseError: If the server replies with an error
            status (400 or higher).
    """
    if path_or_url.startswith(("http://", "https://")):
        async with aiohttp.ClientSession() as session, \
                   session.get(path_or_url) as resp:
            # Fail loudly on 4xx/5xx rather than returning the server's
            # error page as if it were the requested file content.
            resp.raise_for_status()
            return await resp.text()

    with open(path_or_url, encoding="utf-8") as f:
        return f.read()


async def write_file(path_or_url: str, data: str) -> None:
    """Write ``data`` to a local file or upload it with an HTTP PUT.

    Args:
        path_or_url: A local filesystem path, or a URL starting with
            ``http://`` or ``https://`` that receives the data via PUT.
        data: The text to store; it is encoded as UTF-8 in both cases.

    Raises:
        aiohttp.ClientResponseError: If the PUT is answered with an error
            status (400 or higher).
    """
    if path_or_url.startswith(("http://", "https://")):
        async with aiohttp.ClientSession() as session, \
                   session.put(path_or_url,
                               data=data.encode("utf-8")) as resp:
            # Without this check a rejected upload would pass silently and
            # the data would be lost without any error being reported.
            resp.raise_for_status()
        return

    # We should make this async, but as long as this is always run as a
    # standalone program, blocking the event loop won't affect performance
    # in this particular case.
    with open(path_or_url, "w", encoding="utf-8") as f:
        f.write(data)


150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
def make_error_request_output(request: BatchRequestInput,
                              error_msg: str) -> BatchRequestOutput:
    """Build the batch output entry for a request that was rejected.

    Args:
        request: The batch request being answered; only its ``custom_id``
            is copied into the output.
        error_msg: Text stored in the output's ``error`` field.

    Returns:
        A ``BatchRequestOutput`` with fresh random ids whose response data
        carries ``HTTPStatus.BAD_REQUEST`` as status code.
    """
    output_id = f"vllm-{random_uuid()}"
    response_data = BatchResponseData(
        status_code=HTTPStatus.BAD_REQUEST,
        request_id=f"vllm-batch-{random_uuid()}",
    )
    return BatchRequestOutput(
        id=output_id,
        custom_id=request.custom_id,
        response=response_data,
        error=error_msg,
    )


async def make_async_error_request_output(
        request: BatchRequestInput, error_msg: str) -> BatchRequestOutput:
    """Awaitable wrapper around ``make_error_request_output``.

    Lets an immediately-known error be handled through the same awaitable
    interface as a real request. Arguments and result are exactly those of
    ``make_error_request_output``.
    """
    error_output = make_error_request_output(request, error_msg)
    return error_output


169
async def run_request(serving_engine_func: Callable,
                      request: BatchRequestInput,
                      tracker: BatchProgressTracker) -> BatchRequestOutput:
    """Execute one batch request and wrap the result as a batch output.

    Args:
        serving_engine_func: Awaitable callable that is invoked with
            ``request.body`` and produces the response.
        request: The batch request; its ``body`` is sent to the handler and
            its ``custom_id`` is copied to the output.
        tracker: Progress tracker notified once the output has been built.

    Returns:
        A ``BatchRequestOutput`` holding either the successful response
        body, the ``ErrorResponse`` with its status code, or a generic
        error when the handler returned any other kind of object.
    """
    result = await serving_engine_func(request.body)

    success_types = (ChatCompletionResponse, EmbeddingResponse,
                     ScoreResponse)

    if isinstance(result, success_types):
        output_id = f"vllm-{random_uuid()}"
        response_data = BatchResponseData(
            body=result,
            request_id=f"vllm-batch-{random_uuid()}",
        )
        batch_output = BatchRequestOutput(
            id=output_id,
            custom_id=request.custom_id,
            response=response_data,
            error=None,
        )
    elif isinstance(result, ErrorResponse):
        output_id = f"vllm-{random_uuid()}"
        response_data = BatchResponseData(
            status_code=result.code,
            request_id=f"vllm-batch-{random_uuid()}",
        )
        batch_output = BatchRequestOutput(
            id=output_id,
            custom_id=request.custom_id,
            response=response_data,
            error=result,
        )
    else:
        # Anything that is neither a known response nor an ErrorResponse is
        # reported as a streaming request (presumably the handler returns a
        # stream object in that case).
        batch_output = make_error_request_output(
            request, error_msg="Request must not be sent in stream mode")

    tracker.completed()
    return batch_output


async def main(args):
    """Run every request of the input batch and write all results out.

    Steps, in order: create the engine and the OpenAI serving objects that
    match the model configuration, read the input (one JSON request per
    non-empty line), schedule each request on the handler matching its
    ``url``, await all of them, and write one JSON line per result to the
    output.

    Args:
        args: Parsed command-line namespace. Besides the options this
            script declares, it reads ``model``, ``served_model_name`` and
            ``disable_log_requests`` (presumably registered by
            ``AsyncEngineArgs.add_cli_args``).
    """
    served_model_names = ([args.model] if args.served_model_name is None
                          else args.served_model_name)

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(
        engine_args,
        usage_context=UsageContext.OPENAI_BATCH_RUNNER,
    )

    model_config = await engine.get_model_config()
    base_model_paths = [
        BaseModelPath(name=served_name, model_path=args.model)
        for served_name in served_model_names
    ]

    request_logger = (None if args.disable_log_requests else RequestLogger(
        max_log_len=args.max_log_len))

    # Create the openai serving objects.
    openai_serving_models = OpenAIServingModels(
        engine_client=engine,
        model_config=model_config,
        base_model_paths=base_model_paths,
        lora_modules=None,
        prompt_adapters=None,
    )

    # Each API gets a serving object only when the model configuration
    # matches it; otherwise it stays None and requests for it are rejected.
    openai_serving_chat = None
    if model_config.runner_type == "generate":
        openai_serving_chat = OpenAIServingChat(
            engine,
            model_config,
            openai_serving_models,
            args.response_role,
            request_logger=request_logger,
            chat_template=None,
            chat_template_content_format="auto",
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
        )

    openai_serving_embedding = None
    if model_config.task == "embed":
        openai_serving_embedding = OpenAIServingEmbedding(
            engine,
            model_config,
            openai_serving_models,
            request_logger=request_logger,
            chat_template=None,
            chat_template_content_format="auto",
        )

    openai_serving_scores = None
    if model_config.task == "score":
        openai_serving_scores = OpenAIServingScores(
            engine,
            model_config,
            openai_serving_models,
            request_logger=request_logger,
        )

    # Request url -> (handler or None when unsupported, rejection message).
    routes = {
        "/v1/chat/completions":
        (None if openai_serving_chat is None else
         openai_serving_chat.create_chat_completion,
         "The model does not support Chat Completions API"),
        "/v1/embeddings":
        (None if openai_serving_embedding is None else
         openai_serving_embedding.create_embedding,
         "The model does not support Embeddings API"),
        "/v1/score":
        (None if openai_serving_scores is None else
         openai_serving_scores.create_score,
         "The model does not support Scores API"),
    }

    tracker = BatchProgressTracker()
    logger.info("Reading batch from %s...", args.input_file)

    # Submit all requests in the file to the engine "concurrently".
    response_futures: List[Awaitable[BatchRequestOutput]] = []
    batch_text = await read_file(args.input_file)
    for raw_line in batch_text.strip().split("\n"):
        # Skip empty lines.
        request_json = raw_line.strip()
        if not request_json:
            continue

        request = BatchRequestInput.model_validate_json(request_json)

        # Determine the type of request and run it.
        route = routes.get(request.url)
        if route is None:
            response_futures.append(
                make_async_error_request_output(
                    request,
                    error_msg=
                    "Only /v1/chat/completions, /v1/embeddings, and /v1/score "
                    "are supported in the batch endpoint.",
                ))
            continue

        handler_fn, unsupported_msg = route
        if handler_fn is None:
            # Rejected requests are not counted by the progress tracker.
            response_futures.append(
                make_async_error_request_output(
                    request,
                    error_msg=unsupported_msg,
                ))
            continue

        response_futures.append(run_request(handler_fn, request, tracker))
        tracker.submitted()

    with tracker.pbar():
        responses = await asyncio.gather(*response_futures)

    # One JSON document per line; surrounding whitespace (including the
    # final newline) is stripped before writing.
    output_text = "".join(f"{response.model_dump_json()}\n"
                          for response in responses)
    await write_file(args.output_file, output_text.strip())


if __name__ == "__main__":
    # Script entry point: parse the CLI, optionally expose metrics, then
    # run the whole batch to completion.
    args = parse_args()

    logger.info("vLLM batch processing API version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    # Start the Prometheus metrics server. LLMEngine uses the Prometheus client
    # to publish metrics at the /metrics endpoint.
    if not args.enable_metrics:
        logger.info("Prometheus metrics disabled")
    else:
        logger.info("Prometheus metrics enabled")
        start_http_server(port=args.port, addr=args.url)

    asyncio.run(main(args))