"vscode:/vscode.git/clone" did not exist on "dba4d9dec606da028fbb28240e99cabd5a761e6a"
Unverified commit 2cb2340f, authored by Pooya Davoodi, committed by GitHub
Browse files

[Frontend] Add support for transcriptions and translations to run_batch (#33934)


Signed-off-by: Pooya Davoodi <pooya.davoodi@parasail.io>
Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent 4df44c16
...@@ -7,6 +7,7 @@ import tempfile ...@@ -7,6 +7,7 @@ import tempfile
import pytest import pytest
from vllm.assets.audio import AudioAsset
from vllm.entrypoints.openai.run_batch import BatchRequestOutput from vllm.entrypoints.openai.run_batch import BatchRequestOutput
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM" MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
...@@ -42,6 +43,27 @@ INPUT_RERANK_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/re ...@@ -42,6 +43,27 @@ INPUT_RERANK_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/re
INPUT_REASONING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "Qwen/Qwen3-0.6B", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Solve this math problem: 2+2=?"}]}} INPUT_REASONING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "Qwen/Qwen3-0.6B", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Solve this math problem: 2+2=?"}]}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "Qwen/Qwen3-0.6B", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "What is the capital of France?"}]}}""" {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "Qwen/Qwen3-0.6B", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "What is the capital of France?"}]}}"""
# This is a valid but minimal audio file for testing
MINIMAL_WAV_BASE64 = "UklGRiQAAABXQVZFZm10IBAAAAABAAEAQB8AAEAfAAABAAgAZGF0YQAAAAA="

# NOTE: each template below is rendered with str.format(), so the literal JSON
# braces are doubled ("{{" / "}}") and the single "{}" placeholder receives the
# audio location. Every rendered value is one JSON line ending in "\n".

# Transcription request whose audio is embedded inline as a base64 data URL.
INPUT_TRANSCRIPTION_BATCH = (
    '{{"custom_id": "request-1", "method": "POST", "url": "/v1/audio/transcriptions", '
    '"body": {{"model": "openai/whisper-large-v3", "file_url": "data:audio/wav;base64,{}", '
    '"response_format": "json"}}}}\n'
).format(MINIMAL_WAV_BASE64)

# Transcription request whose audio is referenced by the "mary_had_lamb" asset
# URL (presumably an HTTP(S) URL, per the constant's name - TODO confirm).
INPUT_TRANSCRIPTION_HTTP_BATCH = (
    '{{"custom_id": "request-1", "method": "POST", "url": "/v1/audio/transcriptions", '
    '"body": {{"model": "openai/whisper-large-v3", "file_url": "{}", '
    '"response_format": "json"}}}}\n'
).format(AudioAsset("mary_had_lamb").url)

# Translation request for the same asset, asking for plain-text output with
# source language "it", target language "en" and temperature 0.0.
INPUT_TRANSLATION_BATCH = (
    '{{"custom_id": "request-1", "method": "POST", "url": "/v1/audio/translations", '
    '"body": {{"model": "openai/whisper-small", "file_url": "{}", '
    '"response_format": "text", "language": "it", "to_language": "en", '
    '"temperature": 0.0}}}}\n'
).format(AudioAsset("mary_had_lamb").url)
def test_empty_file(): def test_empty_file():
with ( with (
...@@ -238,3 +260,121 @@ def test_reasoning_parser(): ...@@ -238,3 +260,121 @@ def test_reasoning_parser():
] ]
assert reasoning is not None assert reasoning is not None
assert len(reasoning) > 0 assert len(reasoning) > 0
def test_transcription():
    """Transcribe an inline (data URL) audio clip through ``vllm run-batch``.

    Writes INPUT_TRANSCRIPTION_BATCH to a temporary input file, runs the CLI
    with ``openai/whisper-large-v3`` and checks that every output line parses
    as a BatchRequestOutput, reports no error, and carries both ``text`` and
    ``usage`` in the response body.
    """
    with (
        tempfile.NamedTemporaryFile("w") as input_file,
        tempfile.NamedTemporaryFile("r") as output_file,
    ):
        input_file.write(INPUT_TRANSCRIPTION_BATCH)
        # Flush so the subprocess sees the request when it opens the file by name.
        input_file.flush()

        command = [
            "vllm",
            "run-batch",
            "-i",
            input_file.name,
            "-o",
            output_file.name,
            "--model",
            "openai/whisper-large-v3",
        ]
        proc = subprocess.Popen(command)
        proc.communicate()
        proc.wait()
        assert proc.returncode == 0, f"{proc=}"

        contents = output_file.read()
        print(f"\n\ncontents: {contents}\n\n")

        for raw_line in contents.strip().split("\n"):
            # Raises if the line does not match the batch output schema.
            BatchRequestOutput.model_validate_json(raw_line)

            record = json.loads(raw_line)
            assert isinstance(record, dict)
            assert record["error"] is None

            body = record["response"]["body"]
            assert body is not None
            assert "text" in body
            assert "usage" in body
def test_transcription_http_url():
    """Transcribe audio referenced by URL through ``vllm run-batch``.

    Same flow as the data-URL transcription test, but the request points at
    the ``mary_had_lamb`` asset URL, and the transcript must contain the
    phrase "Mary had a little lamb".
    """
    with (
        tempfile.NamedTemporaryFile("w") as input_file,
        tempfile.NamedTemporaryFile("r") as output_file,
    ):
        input_file.write(INPUT_TRANSCRIPTION_HTTP_BATCH)
        # Flush so the subprocess sees the request when it opens the file by name.
        input_file.flush()

        command = [
            "vllm",
            "run-batch",
            "-i",
            input_file.name,
            "-o",
            output_file.name,
            "--model",
            "openai/whisper-large-v3",
        ]
        proc = subprocess.Popen(command)
        proc.communicate()
        proc.wait()
        assert proc.returncode == 0, f"{proc=}"

        contents = output_file.read()

        for raw_line in contents.strip().split("\n"):
            # Raises if the line does not match the batch output schema.
            BatchRequestOutput.model_validate_json(raw_line)

            record = json.loads(raw_line)
            assert isinstance(record, dict)
            assert record["error"] is None

            body = record["response"]["body"]
            assert body is not None
            assert "text" in body
            assert "usage" in body
            # The asset's known content must show up in the transcript.
            assert "Mary had a little lamb" in body["text"]
def test_translation():
    """Translate audio referenced by URL through ``vllm run-batch``.

    Runs INPUT_TRANSLATION_BATCH with ``openai/whisper-small`` and checks that
    every output line is schema-valid, error-free, and that the returned text
    mentions "mary" or "lamb" (case-insensitively).
    """
    with (
        tempfile.NamedTemporaryFile("w") as input_file,
        tempfile.NamedTemporaryFile("r") as output_file,
    ):
        input_file.write(INPUT_TRANSLATION_BATCH)
        # Flush so the subprocess sees the request when it opens the file by name.
        input_file.flush()

        command = [
            "vllm",
            "run-batch",
            "-i",
            input_file.name,
            "-o",
            output_file.name,
            "--model",
            "openai/whisper-small",
        ]
        proc = subprocess.Popen(command)
        proc.communicate()
        proc.wait()
        assert proc.returncode == 0, f"{proc=}"

        contents = output_file.read()

        for raw_line in contents.strip().split("\n"):
            # Raises if the line does not match the batch output schema.
            BatchRequestOutput.model_validate_json(raw_line)

            record = json.loads(raw_line)
            assert isinstance(record, dict)
            assert record["error"] is None

            body = record["response"]["body"]
            assert body is not None
            assert "text" in body

            # Normalise before matching: the body text is coerced to str,
            # trimmed and lower-cased.
            normalized = str(body["text"]).strip().lower()
            assert any(keyword in normalized for keyword in ("mary", "lamb"))
...@@ -2,17 +2,20 @@ ...@@ -2,17 +2,20 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio import asyncio
import base64
import tempfile import tempfile
from argparse import Namespace from argparse import Namespace
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from http import HTTPStatus from http import HTTPStatus
from io import StringIO from io import BytesIO, StringIO
from typing import Any, TypeAlias from typing import Any, TypeAlias
from urllib.parse import urlparse
import aiohttp import aiohttp
import torch import torch
from fastapi import UploadFile
from prometheus_client import start_http_server from prometheus_client import start_http_server
from pydantic import TypeAdapter, field_validator from pydantic import Field, TypeAdapter, field_validator, model_validator
from pydantic_core.core_schema import ValidationInfo from pydantic_core.core_schema import ValidationInfo
from tqdm import tqdm from tqdm import tqdm
...@@ -25,12 +28,28 @@ from vllm.entrypoints.openai.chat_completion.protocol import ( ...@@ -25,12 +28,28 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
) )
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
ErrorInfo,
ErrorResponse, ErrorResponse,
OpenAIBaseModel, OpenAIBaseModel,
) )
from vllm.entrypoints.openai.models.protocol import BaseModelPath from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest, EmbeddingResponse from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionRequest,
TranscriptionResponse,
TranscriptionResponseVerbose,
TranslationRequest,
TranslationResponse,
TranslationResponseVerbose,
)
from vllm.entrypoints.openai.speech_to_text.serving import (
OpenAIServingTranscription,
OpenAIServingTranslation,
)
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingRequest,
EmbeddingResponse,
)
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.pooling.score.protocol import ( from vllm.entrypoints.pooling.score.protocol import (
RerankRequest, RerankRequest,
...@@ -41,6 +60,7 @@ from vllm.entrypoints.pooling.score.protocol import ( ...@@ -41,6 +60,7 @@ from vllm.entrypoints.pooling.score.protocol import (
from vllm.entrypoints.pooling.score.serving import ServingScores from vllm.entrypoints.pooling.score.serving import ServingScores
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager from vllm.reasoning import ReasoningParserManager
from vllm.tasks import SupportedTask
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
...@@ -48,8 +68,73 @@ from vllm.version import __version__ as VLLM_VERSION ...@@ -48,8 +68,73 @@ from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__) logger = init_logger(__name__)
class BatchTranscriptionRequest(TranscriptionRequest):
    """
    Batch transcription request that uses file_url instead of file.

    This class extends TranscriptionRequest but replaces the file field
    with file_url to support batch processing from audio files written in JSON format.
    """

    # Required: where the audio comes from (plain URL or base64 data URL).
    file_url: str = Field(
        ...,
        description=(
            "Either a URL of the audio or a data URL with base64 encoded audio data. "
        ),
    )

    # Override file to be optional and unused for batch processing;
    # exclude=True keeps it out of model_dump() output.
    file: UploadFile | None = Field(default=None, exclude=True)  # type: ignore[assignment]

    @model_validator(mode="before")
    @classmethod
    def validate_no_file(cls, data: Any):
        """Reject raw input dicts that carry a ``file`` key.

        Non-dict input and dicts without ``file`` are returned unchanged.

        Raises:
            ValueError: if ``data`` is a dict containing ``file``.
        """
        has_file_key = isinstance(data, dict) and "file" in data
        if not has_file_key:
            return data
        raise ValueError(
            "The 'file' field is not supported in batch requests. "
            "Use 'file_url' instead."
        )
class BatchTranslationRequest(TranslationRequest):
    """
    Batch translation request that uses file_url instead of file.

    This class extends TranslationRequest but replaces the file field
    with file_url to support batch processing from audio files written in JSON format.
    """

    # Required: where the audio comes from (plain URL or base64 data URL).
    file_url: str = Field(
        ...,
        description=(
            "Either a URL of the audio or a data URL with base64 encoded audio data. "
        ),
    )

    # Override file to be optional and unused for batch processing;
    # exclude=True keeps it out of model_dump() output.
    file: UploadFile | None = Field(default=None, exclude=True)  # type: ignore[assignment]

    @model_validator(mode="before")
    @classmethod
    def validate_no_file(cls, data: Any):
        """Reject raw input dicts that carry a ``file`` key.

        Non-dict input and dicts without ``file`` are returned unchanged.

        Raises:
            ValueError: if ``data`` is a dict containing ``file``.
        """
        has_file_key = isinstance(data, dict) and "file" in data
        if not has_file_key:
            return data
        raise ValueError(
            "The 'file' field is not supported in batch requests. "
            "Use 'file_url' instead."
        )
BatchRequestInputBody: TypeAlias = ( BatchRequestInputBody: TypeAlias = (
ChatCompletionRequest | EmbeddingRequest | ScoreRequest | RerankRequest ChatCompletionRequest
| EmbeddingRequest
| ScoreRequest
| RerankRequest
| BatchTranscriptionRequest
| BatchTranslationRequest
) )
...@@ -88,6 +173,10 @@ class BatchRequestInput(OpenAIBaseModel): ...@@ -88,6 +173,10 @@ class BatchRequestInput(OpenAIBaseModel):
return TypeAdapter(ScoreRequest).validate_python(value) return TypeAdapter(ScoreRequest).validate_python(value)
if url.endswith("/rerank"): if url.endswith("/rerank"):
return RerankRequest.model_validate(value) return RerankRequest.model_validate(value)
if url == "/v1/audio/transcriptions":
return BatchTranscriptionRequest.model_validate(value)
if url == "/v1/audio/translations":
return BatchTranslationRequest.model_validate(value)
return TypeAdapter(BatchRequestInputBody).validate_python(value) return TypeAdapter(BatchRequestInputBody).validate_python(value)
...@@ -104,6 +193,10 @@ class BatchResponseData(OpenAIBaseModel): ...@@ -104,6 +193,10 @@ class BatchResponseData(OpenAIBaseModel):
| EmbeddingResponse | EmbeddingResponse
| ScoreResponse | ScoreResponse
| RerankResponse | RerankResponse
| TranscriptionResponse
| TranscriptionResponseVerbose
| TranslationResponse
| TranslationResponseVerbose
| None | None
) = None ) = None
...@@ -361,6 +454,49 @@ async def write_file( ...@@ -361,6 +454,49 @@ async def write_file(
await write_local_file(path_or_url, batch_outputs) await write_local_file(path_or_url, batch_outputs)
async def download_bytes_from_url(url: str) -> bytes:
    """
    Return the bytes referenced by a URL.

    Args:
        url: Either an HTTP/HTTPS URL or a data URL (data:...;base64,...)

    Returns:
        The decoded payload for a data URL, or the response body for an
        HTTP/HTTPS URL.

    Raises:
        ValueError: if the scheme is not http, https or data; if a data URL
            has no comma separator; or if its header does not mention base64.
        Exception: if the HTTP response status is not 200.
    """
    parsed = urlparse(url)

    if parsed.scheme == "data":
        # Format: data:...;base64,<base64_data>
        # Everything before the first comma is the header, the rest is payload.
        header, separator, encoded = url.partition(",")
        if not separator:
            raise ValueError(f"Invalid data URL format: {url}")
        if "base64" not in header:
            raise ValueError(f"Unsupported data URL encoding: {header}")
        return base64.b64decode(encoded)

    if parsed.scheme in ("http", "https"):
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as resp:
                # Only a plain 200 is accepted as a successful download.
                if resp.status != 200:
                    raise Exception(
                        f"Failed to download data from URL: {url}. Status: {resp.status}"
                    )
                return await resp.read()

    raise ValueError(
        f"Unsupported URL scheme: {parsed.scheme}. "
        "Supported schemes: http, https, data"
    )
def make_error_request_output( def make_error_request_output(
request: BatchRequestInput, error_msg: str request: BatchRequestInput, error_msg: str
) -> BatchRequestOutput: ) -> BatchRequestOutput:
...@@ -391,7 +527,16 @@ async def run_request( ...@@ -391,7 +527,16 @@ async def run_request(
if isinstance( if isinstance(
response, response,
(ChatCompletionResponse, EmbeddingResponse, ScoreResponse, RerankResponse), (
ChatCompletionResponse,
EmbeddingResponse,
ScoreResponse,
RerankResponse,
TranscriptionResponse,
TranscriptionResponseVerbose,
TranslationResponse,
TranslationResponseVerbose,
),
): ):
batch_output = BatchRequestOutput( batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}", id=f"vllm-{random_uuid()}",
...@@ -420,38 +565,130 @@ async def run_request( ...@@ -420,38 +565,130 @@ async def run_request(
return batch_output return batch_output
def handle_endpoint_request(
    request: BatchRequestInput,
    tracker: BatchProgressTracker,
    url_matcher: Callable[[str], bool],
    handler_getter: Callable[[], Callable | None],
    wrapper_fn: Callable[[Callable], Callable] | None = None,
) -> Awaitable[BatchRequestOutput] | None:
    """
    Dispatch one batch request to an endpoint handler, if the URL matches.

    Args:
        request: The batch request input
        tracker: Progress tracker for the batch
        url_matcher: Function that takes a URL and returns True if it matches
        handler_getter: Function that returns the handler function or None
        wrapper_fn: Optional function to wrap the handler (e.g., for transcriptions)

    Returns:
        None if the URL didn't match. Otherwise an
        Awaitable[BatchRequestOutput]: an error output when handler_getter
        yields no handler, or the pending run_request call for the request.
    """
    if not url_matcher(request.url):
        return None

    base_handler = handler_getter()
    if base_handler is None:
        # The URL is recognised but no handler is available for it.
        return make_async_error_request_output(
            request, error_msg=f"Model does not support endpoint: {request.url}"
        )

    # Apply wrapper if provided (e.g., for transcriptions/translations)
    effective_handler = (
        base_handler if wrapper_fn is None else wrapper_fn(base_handler)
    )

    # Count the request as submitted before handing it to run_request.
    tracker.submitted()
    return run_request(effective_handler, request, tracker)
def make_transcription_wrapper(is_translation: bool):
    """
    Build a decorator that adapts batch audio requests to a serving handler.

    The wrapped handler receives a BatchTranscriptionRequest or
    BatchTranslationRequest, fetches the audio named by its file_url, rebuilds
    a TranscriptionRequest or TranslationRequest around an in-memory upload,
    and calls the underlying handler as ``handler_fn(audio_data, request)``.

    Args:
        is_translation: If True, process as translation; otherwise process
            as transcription

    Returns:
        A function that takes a handler and returns a wrapped handler
    """

    def wrapper(handler_fn: Callable):
        async def transcription_wrapper(
            batch_request_body: (BatchTranscriptionRequest | BatchTranslationRequest),
        ) -> (
            TranscriptionResponse
            | TranscriptionResponseVerbose
            | TranslationResponse
            | TranslationResponseVerbose
            | ErrorResponse
        ):
            try:
                # Download data from URL
                audio_data = await download_bytes_from_url(batch_request_body.file_url)

                # Wrap the downloaded bytes so they look like an uploaded file.
                upload = UploadFile(
                    file=BytesIO(audio_data),
                    filename="audio.bin",
                )

                # Copy every field except file_url, then attach the upload
                # under the regular request's "file" field.
                request_dict = batch_request_body.model_dump(exclude={"file_url"})
                request_dict["file"] = upload

                # Same flow for both operations; only the target model differs.
                request_model = (
                    TranslationRequest if is_translation else TranscriptionRequest
                )
                regular_request = request_model.model_validate(request_dict)
                return await handler_fn(audio_data, regular_request)
            except Exception as e:
                # Any failure (download, validation, handler) is reported as a
                # 400-style error response instead of propagating.
                operation = "translation" if is_translation else "transcription"
                error_info = ErrorInfo(
                    message=f"Failed to process {operation}: {str(e)}",
                    type="BadRequestError",
                    code=HTTPStatus.BAD_REQUEST.value,
                )
                return ErrorResponse(error=error_info)

        return transcription_wrapper

    return wrapper
def build_endpoint_registry(
engine_client: EngineClient, engine_client: EngineClient,
args: Namespace, args: Namespace,
) -> None: base_model_paths: list[BaseModelPath],
if args.served_model_name is not None: request_logger: RequestLogger | None,
served_model_names = args.served_model_name supported_tasks: tuple[SupportedTask, ...],
else: ) -> dict[str, dict[str, Any]]:
served_model_names = [args.model] """
Build the endpoint registry with all serving objects and handler configurations.
if args.enable_log_requests:
request_logger = RequestLogger(max_log_len=args.max_log_len)
else:
request_logger = None
base_model_paths = [ Args:
BaseModelPath(name=name, model_path=args.model) for name in served_model_names engine_client: The engine client
] args: Command line arguments
base_model_paths: List of base model paths
request_logger: Optional request logger
supported_tasks: Tuple of supported tasks
Returns:
Dictionary mapping endpoint keys to their configurations
"""
model_config = engine_client.model_config model_config = engine_client.model_config
supported_tasks = await engine_client.get_supported_tasks()
logger.info("Supported tasks: %s", supported_tasks)
# Create the openai serving objects. # Create the openai serving objects.
openai_serving_models = OpenAIServingModels( openai_serving_models = OpenAIServingModels(
...@@ -507,97 +744,168 @@ async def run_batch( ...@@ -507,97 +744,168 @@ async def run_batch(
else None else None
) )
tracker = BatchProgressTracker() openai_serving_transcription = (
logger.info("Reading batch from %s...", args.input_file) OpenAIServingTranscription(
engine_client,
# Submit all requests in the file to the engine "concurrently". openai_serving_models,
response_futures: list[Awaitable[BatchRequestOutput]] = [] request_logger=request_logger,
for request_json in (await read_file(args.input_file)).strip().split("\n"): enable_force_include_usage=args.enable_force_include_usage,
# Skip empty lines. )
request_json = request_json.strip() if "transcription" in supported_tasks
if not request_json: else None
continue )
request = BatchRequestInput.model_validate_json(request_json) openai_serving_translation = (
OpenAIServingTranslation(
engine_client,
openai_serving_models,
request_logger=request_logger,
enable_force_include_usage=args.enable_force_include_usage,
)
if "transcription" in supported_tasks
else None
)
# Determine the type of request and run it. # Registry of endpoint configurations
if request.url == "/v1/chat/completions": endpoint_registry: dict[str, dict[str, Any]] = {
chat_handler_fn = ( "completions": {
"url_matcher": lambda url: url == "/v1/chat/completions",
"handler_getter": lambda: (
openai_serving_chat.create_chat_completion openai_serving_chat.create_chat_completion
if openai_serving_chat is not None if openai_serving_chat is not None
else None else None
) ),
if chat_handler_fn is None: "wrapper_fn": None,
response_futures.append( },
make_async_error_request_output( "embeddings": {
request, "url_matcher": lambda url: url == "/v1/embeddings",
error_msg="The model does not support Chat Completions API", "handler_getter": lambda: (
)
)
continue
response_futures.append(run_request(chat_handler_fn, request, tracker))
tracker.submitted()
elif request.url == "/v1/embeddings":
embed_handler_fn = (
openai_serving_embedding.create_embedding openai_serving_embedding.create_embedding
if openai_serving_embedding is not None if openai_serving_embedding is not None
else None else None
) ),
if embed_handler_fn is None: "wrapper_fn": None,
response_futures.append( },
make_async_error_request_output( "score": {
request, "url_matcher": lambda url: url.endswith("/score"),
error_msg="The model does not support Embeddings API", "handler_getter": lambda: (
)
)
continue
response_futures.append(run_request(embed_handler_fn, request, tracker))
tracker.submitted()
elif request.url.endswith("/score"):
score_handler_fn = (
openai_serving_scores.create_score openai_serving_scores.create_score
if openai_serving_scores is not None if openai_serving_scores is not None
else None else None
) ),
if score_handler_fn is None: "wrapper_fn": None,
response_futures.append( },
make_async_error_request_output( "rerank": {
request, "url_matcher": lambda url: url.endswith("/rerank"),
error_msg="The model does not support Scores API", "handler_getter": lambda: (
)
)
continue
response_futures.append(run_request(score_handler_fn, request, tracker))
tracker.submitted()
elif request.url.endswith("/rerank"):
rerank_handler_fn = (
openai_serving_scores.do_rerank openai_serving_scores.do_rerank
if openai_serving_scores is not None if openai_serving_scores is not None
else None else None
),
"wrapper_fn": None,
},
"transcriptions": {
"url_matcher": lambda url: url == "/v1/audio/transcriptions",
"handler_getter": lambda: (
openai_serving_transcription.create_transcription
if openai_serving_transcription is not None
else None
),
"wrapper_fn": make_transcription_wrapper(is_translation=False),
},
"translations": {
"url_matcher": lambda url: url == "/v1/audio/translations",
"handler_getter": lambda: (
openai_serving_translation.create_translation
if openai_serving_translation is not None
else None
),
"wrapper_fn": make_transcription_wrapper(is_translation=True),
},
}
return endpoint_registry
def validate_run_batch_args(args):
    """Check that the configured reasoning parser, if any, is registered.

    Args:
        args: Parsed arguments; ``args.structured_outputs_config.reasoning_parser``
            is read.

    Raises:
        KeyError: if a reasoning parser is set but is not among the parsers
            reported by ReasoningParserManager.list_registered().
    """
    valid_reasoning_parsers = ReasoningParserManager.list_registered()
    reasoning_parser = args.structured_outputs_config.reasoning_parser

    # An unset (falsy) parser needs no validation; a registered one is fine.
    if not reasoning_parser or reasoning_parser in valid_reasoning_parsers:
        return

    raise KeyError(
        f"invalid reasoning parser: {reasoning_parser} "
        f"(chose from {{ {','.join(valid_reasoning_parsers)} }})"
    )
async def run_batch(
engine_client: EngineClient,
args: Namespace,
) -> None:
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
if args.enable_log_requests:
request_logger = RequestLogger(max_log_len=args.max_log_len)
else:
request_logger = None
base_model_paths = [
BaseModelPath(name=name, model_path=args.model) for name in served_model_names
]
supported_tasks = await engine_client.get_supported_tasks()
logger.info("Supported tasks: %s", supported_tasks)
endpoint_registry = build_endpoint_registry(
engine_client=engine_client,
args=args,
base_model_paths=base_model_paths,
request_logger=request_logger,
supported_tasks=supported_tasks,
)
tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file)
# Submit all requests in the file to the engine "concurrently".
response_futures: list[Awaitable[BatchRequestOutput]] = []
for request_json in (await read_file(args.input_file)).strip().split("\n"):
# Skip empty lines.
request_json = request_json.strip()
if not request_json:
continue
request = BatchRequestInput.model_validate_json(request_json)
# Use the last segment of the URL as the endpoint key.
# More advanced URL matching is done in url_matcher of endpoint_registry.
endpoint_key = request.url.split("/")[-1]
result = None
if endpoint_key in endpoint_registry:
endpoint_config = endpoint_registry[endpoint_key]
result = handle_endpoint_request(
request,
tracker,
url_matcher=endpoint_config["url_matcher"],
handler_getter=endpoint_config["handler_getter"],
wrapper_fn=endpoint_config["wrapper_fn"],
) )
if rerank_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Rerank API",
)
)
continue
response_futures.append(run_request(rerank_handler_fn, request, tracker)) if result is not None:
tracker.submitted() response_futures.append(result)
else: else:
response_futures.append( response_futures.append(
make_async_error_request_output( make_async_error_request_output(
request, request,
error_msg=f"URL {request.url} was used. " error_msg=f"URL {request.url} was used. "
"Supported endpoints: /v1/chat/completions, /v1/embeddings," "Supported endpoints: /v1/chat/completions, /v1/embeddings,"
" /score, /rerank ." " /v1/audio/transcriptions, /v1/audio/translations, /score, "
"See vllm/entrypoints/openai/api_server.py for supported " " /rerank. See vllm/entrypoints/openai/api_server.py "
"score/rerank versions.", "for supported score/rerank versions.",
) )
) )
......
...@@ -54,7 +54,10 @@ class OpenAIServingTranscription(OpenAISpeechToText): ...@@ -54,7 +54,10 @@ class OpenAIServingTranscription(OpenAISpeechToText):
) )
async def create_transcription( async def create_transcription(
self, audio_data: bytes, request: TranscriptionRequest, raw_request: Request self,
audio_data: bytes,
request: TranscriptionRequest,
raw_request: Request | None = None,
) -> ( ) -> (
TranscriptionResponse TranscriptionResponse
| TranscriptionResponseVerbose | TranscriptionResponseVerbose
...@@ -124,7 +127,10 @@ class OpenAIServingTranslation(OpenAISpeechToText): ...@@ -124,7 +127,10 @@ class OpenAIServingTranslation(OpenAISpeechToText):
) )
async def create_translation( async def create_translation(
self, audio_data: bytes, request: TranslationRequest, raw_request: Request self,
audio_data: bytes,
request: TranslationRequest,
raw_request: Request | None = None,
) -> ( ) -> (
TranslationResponse TranslationResponse
| TranslationResponseVerbose | TranslationResponseVerbose
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment