Unverified Commit 3ce2c050 authored by zifeitong's avatar zifeitong Committed by GitHub
Browse files

[Fix] Correct OpenAI batch response format (#5554)

parent 1c0afa13
...@@ -672,6 +672,17 @@ class BatchRequestInput(OpenAIBaseModel): ...@@ -672,6 +672,17 @@ class BatchRequestInput(OpenAIBaseModel):
body: Union[ChatCompletionRequest, ] body: Union[ChatCompletionRequest, ]
class BatchResponseData(OpenAIBaseModel):
    """Response envelope stored in the ``response`` field of a batch output line.

    Wraps the outcome of a single batched API request: the HTTP-style status
    code, an identifier for the request, and (when available) the response
    body.
    """

    # HTTP status code of the response.
    status_code: int = 200

    # A unique identifier for the API request.
    request_id: str

    # The body of the response.
    # Optional with a None default: the error path builds this object from
    # only ``status_code`` and ``request_id`` (there is no completion to
    # attach), so a required ``body`` would make every failed request fail
    # validation instead of being reported.
    body: Optional[ChatCompletionResponse] = None
class BatchRequestOutput(OpenAIBaseModel): class BatchRequestOutput(OpenAIBaseModel):
""" """
The per-line object of the batch output and error files The per-line object of the batch output and error files
...@@ -683,7 +694,7 @@ class BatchRequestOutput(OpenAIBaseModel): ...@@ -683,7 +694,7 @@ class BatchRequestOutput(OpenAIBaseModel):
# inputs. # inputs.
custom_id: str custom_id: str
response: Optional[ChatCompletionResponse] response: Optional[BatchResponseData]
# For requests that failed with a non-HTTP error, this will contain more # For requests that failed with a non-HTTP error, this will contain more
# information on the cause of the failure. # information on the cause of the failure.
......
...@@ -10,7 +10,9 @@ from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str ...@@ -10,7 +10,9 @@ from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (BatchRequestInput, from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchRequestOutput, BatchRequestOutput,
ChatCompletionResponse) BatchResponseData,
ChatCompletionResponse,
ErrorResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
...@@ -77,20 +79,27 @@ async def run_request(chat_serving: OpenAIServingChat, ...@@ -77,20 +79,27 @@ async def run_request(chat_serving: OpenAIServingChat,
request: BatchRequestInput) -> BatchRequestOutput: request: BatchRequestInput) -> BatchRequestOutput:
chat_request = request.body chat_request = request.body
chat_response = await chat_serving.create_chat_completion(chat_request) chat_response = await chat_serving.create_chat_completion(chat_request)
if isinstance(chat_response, ChatCompletionResponse): if isinstance(chat_response, ChatCompletionResponse):
batch_output = BatchRequestOutput( batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}", id=f"vllm-{random_uuid()}",
custom_id=request.custom_id, custom_id=request.custom_id,
response=chat_response, response=BatchResponseData(
body=chat_response, request_id=f"vllm-batch-{random_uuid()}"),
error=None, error=None,
) )
else: elif isinstance(chat_response, ErrorResponse):
batch_output = BatchRequestOutput( batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}", id=f"vllm-{random_uuid()}",
custom_id=request.custom_id, custom_id=request.custom_id,
response=None, response=BatchResponseData(
status_code=chat_response.code,
request_id=f"vllm-batch-{random_uuid()}"),
error=chat_response, error=chat_response,
) )
else:
raise ValueError("Request must not be sent in stream mode")
return batch_output return batch_output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment