# run_batch.py
import asyncio
from io import StringIO
from typing import Awaitable, Callable, List

import aiohttp

from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
# yapf: disable
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
                                              BatchRequestOutput,
                                              BatchResponseData,
                                              ChatCompletionResponse,
                                              EmbeddingResponse, ErrorResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION

# Module-level logger, named after this module via vLLM's init_logger.
logger = init_logger(__name__)


def parse_args():
    """Build the batch-runner command line and parse it.

    Registers, in this order: the batch I/O options (``--input-file``,
    ``--output-file``, ``--response-role``), every engine option contributed
    by ``AsyncEngineArgs.add_cli_args``, and finally ``--max-log-len``.

    Returns:
        The parsed argument namespace.
    """
    cli = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible batch runner.")

    # Where the batch requests are read from and where results are written.
    input_help = (
        "The path or url to a single input file. Currently supports local file "
        "paths, or the http protocol (http or https). If a URL is specified, "
        "the file should be available via HTTP GET.")
    output_help = (
        "The path or url to a single output file. Currently supports "
        "local file paths, or web (http or https) urls. If a URL is specified,"
        " the file should be available via HTTP PUT.")
    cli.add_argument("-i", "--input-file", required=True, type=str,
                     help=input_help)
    cli.add_argument("-o", "--output-file", required=True, type=str,
                     help=output_help)
    cli.add_argument("--response-role",
                     type=nullable_str,
                     default="assistant",
                     help="The role name to return if "
                     "`request.add_generation_prompt=True`.")

    # Engine configuration flags are registered by AsyncEngineArgs, which
    # returns the (possibly same) parser object to keep building on.
    cli = AsyncEngineArgs.add_cli_args(cli)

    cli.add_argument('--max-log-len',
                     type=int,
                     default=None,
                     help='Max number of prompt characters or prompt '
                     'ID numbers being printed in log.'
                     '\n\nDefault: Unlimited')

    return cli.parse_args()


async def read_file(path_or_url: str) -> str:
    """Return the text content of a local file or an HTTP(S) URL.

    Args:
        path_or_url: A local filesystem path, or a URL starting with
            ``http://`` or ``https://`` that is fetched with HTTP GET.

    Returns:
        The decoded content (local files are decoded as UTF-8).

    Raises:
        aiohttp.ClientResponseError: If the server answers the GET with an
            error status (>= 400).
        OSError: If the local file cannot be opened.
    """
    if path_or_url.startswith(("http://", "https://")):
        async with aiohttp.ClientSession() as session, \
                   session.get(path_or_url) as resp:
            # Without this check an error page (404, 500, ...) would be
            # returned as if it were the batch input and fail later with a
            # confusing JSON validation error.
            resp.raise_for_status()
            return await resp.text()

    with open(path_or_url, "r", encoding="utf-8") as f:
        return f.read()


async def write_file(path_or_url: str, data: str) -> None:
    """Write ``data`` to a local file or upload it to an HTTP(S) URL.

    Args:
        path_or_url: A local filesystem path (created or truncated), or a
            URL starting with ``http://`` or ``https://`` that receives the
            UTF-8 encoded data via HTTP PUT.
        data: The text to write.

    Raises:
        aiohttp.ClientResponseError: If the server rejects the PUT with an
            error status (>= 400).
        OSError: If the local file cannot be written.
    """
    if path_or_url.startswith(("http://", "https://")):
        payload = data.encode("utf-8")
        async with aiohttp.ClientSession() as session, \
                   session.put(path_or_url, data=payload) as resp:
            # A rejected upload must not be reported as a successful run.
            resp.raise_for_status()
        return

    # We should make this async, but as long as this is always run as a
    # standalone program, blocking the event loop won't affect performance
    # in this particular case.
    with open(path_or_url, "w", encoding="utf-8") as f:
        f.write(data)


async def run_request(serving_engine_func: Callable,
                      request: BatchRequestInput) -> BatchRequestOutput:
    """Execute one batch request and wrap the outcome as a batch output.

    Args:
        serving_engine_func: Serving coroutine function that is awaited with
            ``request.body`` (``main`` passes chat-completion or embedding
            creation).
        request: One parsed line of the batch input file.

    Returns:
        A ``BatchRequestOutput`` holding either the successful response body,
        or, when the serving function returned an ``ErrorResponse``, that
        error plus its status code.

    Raises:
        ValueError: If the serving function returned any other kind of
            object; the message indicates stream mode is not supported.
    """
    result = await serving_engine_func(request.body)

    # Output id is drawn before the per-response request id.
    output_id = f"vllm-{random_uuid()}"

    if isinstance(result, (ChatCompletionResponse, EmbeddingResponse)):
        response_data = BatchResponseData(
            body=result, request_id=f"vllm-batch-{random_uuid()}")
        error = None
    elif isinstance(result, ErrorResponse):
        response_data = BatchResponseData(
            status_code=result.code,
            request_id=f"vllm-batch-{random_uuid()}")
        error = result
    else:
        raise ValueError("Request must not be sent in stream mode")

    return BatchRequestOutput(
        id=output_id,
        custom_id=request.custom_id,
        response=response_data,
        error=error,
    )


async def main(args):
    """Run every request in ``args.input_file`` and write the results.

    Builds an engine from the CLI arguments and wraps it in the OpenAI chat
    and embedding serving objects. It then awaits all requests together with
    ``asyncio.gather`` and writes one JSON line per request, in input order,
    to ``args.output_file``.

    Args:
        args: Namespace returned by ``parse_args``.

    Raises:
        ValueError: If a request targets an endpoint other than
            ``/v1/chat/completions`` or ``/v1/embeddings``, or if a serving
            function returns a streaming response.
    """
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)

    # When using single vLLM without engine_use_ray
    model_config = await engine.get_model_config()

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    # Create the openai serving objects.
    openai_serving_chat = OpenAIServingChat(
        engine,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=None,
        prompt_adapters=None,
        request_logger=request_logger,
        chat_template=None,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        engine,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )

    # Supported endpoints and the serving coroutine that handles each.
    handlers = {
        "/v1/chat/completions": openai_serving_chat.create_chat_completion,
        "/v1/embeddings": openai_serving_embedding.create_embedding,
    }

    # Parse and route every line before creating any coroutine, so a bad
    # line fails fast without leaving earlier coroutines never awaited.
    jobs = []
    for request_json in (await read_file(args.input_file)).strip().split("\n"):
        # Skip empty lines.
        request_json = request_json.strip()
        if not request_json:
            continue

        request = BatchRequestInput.model_validate_json(request_json)
        handler = handlers.get(request.url)
        if handler is None:
            raise ValueError("Only /v1/chat/completions and /v1/embeddings are"
                             " supported in the batch endpoint.")
        jobs.append((handler, request))

    # Submit all requests in the file to the engine "concurrently".
    response_futures: List[Awaitable[BatchRequestOutput]] = [
        run_request(handler, request) for handler, request in jobs
    ]
    responses = await asyncio.gather(*response_futures)

    output_buffer = StringIO()
    for response in responses:
        print(response.model_dump_json(), file=output_buffer)

    await write_file(args.output_file, output_buffer.getvalue().strip())


if __name__ == "__main__":
    args = parse_args()

    # This entrypoint is the offline batch runner, not the API server, so
    # the startup line is labelled accordingly.
    logger.info("vLLM batch processing API version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    asyncio.run(main(args))