Unverified commit e3b2324e, authored by Pooya Davoodi and committed by GitHub
Browse files

[Frontend] Use init_app_state and FrontendArgs in run_batch (#32967)


Signed-off-by: Pooya Davoodi <pooya.davoodi@parasail.io>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
parent dbf0da81
...@@ -447,7 +447,7 @@ def test_metrics_exist_run_batch(): ...@@ -447,7 +447,7 @@ def test_metrics_exist_run_batch():
"--model", "--model",
"intfloat/multilingual-e5-small", "intfloat/multilingual-e5-small",
"--enable-metrics", "--enable-metrics",
"--url", "--host",
base_url, base_url,
"--port", "--port",
port, port,
......
...@@ -10,59 +10,361 @@ import pytest ...@@ -10,59 +10,361 @@ import pytest
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
from vllm.entrypoints.openai.run_batch import BatchRequestOutput from vllm.entrypoints.openai.run_batch import BatchRequestOutput
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM" CHAT_MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
EMBEDDING_MODEL_NAME = "intfloat/multilingual-e5-small"
# ruff: noqa: E501 RERANKER_MODEL_NAME = "BAAI/bge-reranker-v2-m3"
INPUT_BATCH = ( REASONING_MODEL_NAME = "Qwen/Qwen3-0.6B"
'{{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {{"model": "{0}", "messages": [{{"role": "system", "content": "You are a helpful assistant."}},{{"role": "user", "content": "Hello world!"}}],"max_tokens": 1000}}}}\n' SPEECH_LARGE_MODEL_NAME = "openai/whisper-large-v3"
'{{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {{"model": "{0}", "messages": [{{"role": "system", "content": "You are an unhelpful assistant."}},{{"role": "user", "content": "Hello world!"}}],"max_tokens": 1000}}}}\n' SPEECH_SMALL_MODEL_NAME = "openai/whisper-small"
'{{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {{"model": "NonExistModel", "messages": [{{"role": "system", "content": "You are an unhelpful assistant."}},{{"role": "user", "content": "Hello world!"}}],"max_tokens": 1000}}}}\n'
'{{"custom_id": "request-4", "method": "POST", "url": "/bad_url", "body": {{"model": "{0}", "messages": [{{"role": "system", "content": "You are an unhelpful assistant."}},{{"role": "user", "content": "Hello world!"}}],"max_tokens": 1000}}}}\n' INPUT_BATCH = "\n".join(
'{{"custom_id": "request-5", "method": "POST", "url": "/v1/chat/completions", "body": {{"stream": "True", "model": "{0}", "messages": [{{"role": "system", "content": "You are an unhelpful assistant."}},{{"role": "user", "content": "Hello world!"}}],"max_tokens": 1000}}}}' json.dumps(req)
).format(MODEL_NAME) for req in [
{
INVALID_INPUT_BATCH = ( "custom_id": "request-1",
'{{"invalid_field": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {{"model": "{0}", "messages": [{{"role": "system", "content": "You are a helpful assistant."}},{{"role": "user", "content": "Hello world!"}}],"max_tokens": 1000}}}}\n' "method": "POST",
'{{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {{"model": "{0}", "messages": [{{"role": "system", "content": "You are an unhelpful assistant."}},{{"role": "user", "content": "Hello world!"}}],"max_tokens": 1000}}}}' "url": "/v1/chat/completions",
).format(MODEL_NAME) "body": {
"model": CHAT_MODEL_NAME,
INPUT_EMBEDDING_BATCH = ( "messages": [
'{"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are a helpful assistant."}}\n' {
'{"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are an unhelpful assistant."}}\n' "role": "system",
'{"custom_id": "request-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "Hello world!"}}\n' "content": "You are a helpful assistant.",
'{"custom_id": "request-4", "method": "POST", "url": "/v1/embeddings", "body": {"model": "NonExistModel", "input": "Hello world!"}}' },
{"role": "user", "content": "Hello world!"},
],
"max_tokens": 1000,
},
},
{
"custom_id": "request-2",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": CHAT_MODEL_NAME,
"messages": [
{
"role": "system",
"content": "You are an unhelpful assistant.",
},
{"role": "user", "content": "Hello world!"},
],
"max_tokens": 1000,
},
},
{
"custom_id": "request-3",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": "NonExistModel",
"messages": [
{
"role": "system",
"content": "You are an unhelpful assistant.",
},
{"role": "user", "content": "Hello world!"},
],
"max_tokens": 1000,
},
},
{
"custom_id": "request-4",
"method": "POST",
"url": "/bad_url",
"body": {
"model": CHAT_MODEL_NAME,
"messages": [
{
"role": "system",
"content": "You are an unhelpful assistant.",
},
{"role": "user", "content": "Hello world!"},
],
"max_tokens": 1000,
},
},
{
"custom_id": "request-5",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"stream": "True",
"model": CHAT_MODEL_NAME,
"messages": [
{
"role": "system",
"content": "You are an unhelpful assistant.",
},
{"role": "user", "content": "Hello world!"},
],
"max_tokens": 1000,
},
},
]
) )
INPUT_SCORE_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "queries": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}} INVALID_INPUT_BATCH = "\n".join(
{"custom_id": "request-2", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "queries": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}""" json.dumps(req)
for req in [
{
"invalid_field": "request-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": CHAT_MODEL_NAME,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello world!"},
],
"max_tokens": 1000,
},
},
{
"custom_id": "request-2",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": CHAT_MODEL_NAME,
"messages": [
{"role": "system", "content": "You are an unhelpful assistant."},
{"role": "user", "content": "Hello world!"},
],
"max_tokens": 1000,
},
},
]
)
# JSONL batch of four /v1/embeddings requests: three address
# EMBEDDING_MODEL_NAME and the last one names a model called "NonExistModel".
INPUT_EMBEDDING_BATCH = "\n".join(
    json.dumps(
        {
            "custom_id": f"request-{index}",
            "method": "POST",
            "url": "/v1/embeddings",
            "body": {"model": model, "input": text},
        }
    )
    for index, (model, text) in enumerate(
        [
            (EMBEDDING_MODEL_NAME, "You are a helpful assistant."),
            (EMBEDDING_MODEL_NAME, "You are an unhelpful assistant."),
            (EMBEDDING_MODEL_NAME, "Hello world!"),
            ("NonExistModel", "Hello world!"),
        ],
        start=1,
    )
)
INPUT_RERANK_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}} _SCORE_RERANK_DOCUMENTS = [
{"custom_id": "request-2", "method": "POST", "url": "/v1/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}} "The capital of Brazil is Brasilia.",
{"custom_id": "request-2", "method": "POST", "url": "/v2/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}""" "The capital of France is Paris.",
]
# JSONL batch with one score request per URL spelling used by the test:
# the bare "/score" route and the "/v1/score" route.
INPUT_SCORE_BATCH = "\n".join(
    json.dumps(
        {
            "custom_id": f"request-{index}",
            "method": "POST",
            "url": url,
            "body": {
                "model": RERANKER_MODEL_NAME,
                "queries": "What is the capital of France?",
                "documents": _SCORE_RERANK_DOCUMENTS,
            },
        }
    )
    for index, url in enumerate(("/score", "/v1/score"), start=1)
)
# JSONL batch with one rerank request per URL spelling used by the test:
# "/rerank", "/v1/rerank" and "/v2/rerank".
#
# The custom_id is derived from the position in the batch so that every
# request gets a distinct id ("request-1" .. "request-3"); previously the
# third request reused "request-2", which made two results indistinguishable.
INPUT_RERANK_BATCH = "\n".join(
    json.dumps(
        {
            "custom_id": f"request-{index}",
            "method": "POST",
            "url": url,
            "body": {
                "model": RERANKER_MODEL_NAME,
                "query": "What is the capital of France?",
                "documents": _SCORE_RERANK_DOCUMENTS,
            },
        }
    )
    for index, url in enumerate(("/rerank", "/v1/rerank", "/v2/rerank"), start=1)
)
INPUT_REASONING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "Qwen/Qwen3-0.6B", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Solve this math problem: 2+2=?"}]}} INPUT_REASONING_BATCH = "\n".join(
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "Qwen/Qwen3-0.6B", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "What is the capital of France?"}]}}""" json.dumps(req)
for req in [
{
"custom_id": "request-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": REASONING_MODEL_NAME,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Solve this math problem: 2+2=?"},
],
},
},
{
"custom_id": "request-2",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": REASONING_MODEL_NAME,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the capital of France?"},
],
},
},
]
)
# This is a valid but minimal audio file for testing
MINIMAL_WAV_BASE64 = "UklGRiQAAABXQVZFZm10IBAAAAABAAEAQB8AAEAfAAABAAgAZGF0YQAAAAA=" MINIMAL_WAV_BASE64 = "UklGRiQAAABXQVZFZm10IBAAAAABAAEAQB8AAEAfAAABAAgAZGF0YQAAAAA="
INPUT_TRANSCRIPTION_BATCH = ( INPUT_TRANSCRIPTION_BATCH = (
'{{"custom_id": "request-1", "method": "POST", "url": "/v1/audio/transcriptions", ' json.dumps(
'"body": {{"model": "openai/whisper-large-v3", "file_url": "data:audio/wav;base64,{}", ' {
'"response_format": "json"}}}}\n' "custom_id": "request-1",
).format(MINIMAL_WAV_BASE64) "method": "POST",
"url": "/v1/audio/transcriptions",
"body": {
"model": SPEECH_LARGE_MODEL_NAME,
"file_url": f"data:audio/wav;base64,{MINIMAL_WAV_BASE64}",
"response_format": "json",
},
}
)
+ "\n"
)
INPUT_TRANSCRIPTION_HTTP_BATCH = ( INPUT_TRANSCRIPTION_HTTP_BATCH = (
'{{"custom_id": "request-1", "method": "POST", "url": "/v1/audio/transcriptions", ' json.dumps(
'"body": {{"model": "openai/whisper-large-v3", "file_url": "{}", ' {
'"response_format": "json"}}}}\n' "custom_id": "request-1",
).format(AudioAsset("mary_had_lamb").url) "method": "POST",
"url": "/v1/audio/transcriptions",
"body": {
"model": SPEECH_LARGE_MODEL_NAME,
"file_url": AudioAsset("mary_had_lamb").url,
"response_format": "json",
},
}
)
+ "\n"
)
INPUT_TRANSLATION_BATCH = ( INPUT_TRANSLATION_BATCH = (
'{{"custom_id": "request-1", "method": "POST", "url": "/v1/audio/translations", ' json.dumps(
'"body": {{"model": "openai/whisper-small", "file_url": "{}", ' {
'"response_format": "text", "language": "it", "to_language": "en", ' "custom_id": "request-1",
'"temperature": 0.0}}}}\n' "method": "POST",
).format(AudioAsset("mary_had_lamb").url) "url": "/v1/audio/translations",
"body": {
"model": SPEECH_SMALL_MODEL_NAME,
"file_url": AudioAsset("mary_had_lamb").url,
"response_format": "text",
"language": "it",
"to_language": "en",
"temperature": 0.0,
},
}
)
+ "\n"
)
# Function-tool definition offered to the model in the tool-calling batch.
# "location" is the only required parameter; "unit" is limited to two values.
WEATHER_TOOL = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        # JSON Schema describing the arguments of get_current_weather.
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    },
}
# Single-line batch: one chat-completion request that offers WEATHER_TOOL and
# sets tool_choice to "required".
INPUT_TOOL_CALLING_BATCH = json.dumps(
    {
        "custom_id": "request-1",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": REASONING_MODEL_NAME,
            "messages": [
                {
                    "role": "user",
                    "content": "What is the weather in San Francisco?",
                },
            ],
            "tools": [WEATHER_TOOL],
            "tool_choice": "required",
            "max_tokens": 1000,
        },
    }
)
def test_empty_file(): def test_empty_file():
...@@ -81,7 +383,7 @@ def test_empty_file(): ...@@ -81,7 +383,7 @@ def test_empty_file():
"-o", "-o",
output_file.name, output_file.name,
"--model", "--model",
"intfloat/multilingual-e5-small", EMBEDDING_MODEL_NAME,
], ],
) )
proc.communicate() proc.communicate()
...@@ -108,7 +410,7 @@ def test_completions(): ...@@ -108,7 +410,7 @@ def test_completions():
"-o", "-o",
output_file.name, output_file.name,
"--model", "--model",
MODEL_NAME, CHAT_MODEL_NAME,
], ],
) )
proc.communicate() proc.communicate()
...@@ -141,7 +443,7 @@ def test_completions_invalid_input(): ...@@ -141,7 +443,7 @@ def test_completions_invalid_input():
"-o", "-o",
output_file.name, output_file.name,
"--model", "--model",
MODEL_NAME, CHAT_MODEL_NAME,
], ],
) )
proc.communicate() proc.communicate()
...@@ -165,7 +467,7 @@ def test_embeddings(): ...@@ -165,7 +467,7 @@ def test_embeddings():
"-o", "-o",
output_file.name, output_file.name,
"--model", "--model",
"intfloat/multilingual-e5-small", EMBEDDING_MODEL_NAME,
], ],
) )
proc.communicate() proc.communicate()
...@@ -196,7 +498,7 @@ def test_score(input_batch): ...@@ -196,7 +498,7 @@ def test_score(input_batch):
"-o", "-o",
output_file.name, output_file.name,
"--model", "--model",
"BAAI/bge-reranker-v2-m3", RERANKER_MODEL_NAME,
], ],
) )
proc.communicate() proc.communicate()
...@@ -234,7 +536,7 @@ def test_reasoning_parser(): ...@@ -234,7 +536,7 @@ def test_reasoning_parser():
"-o", "-o",
output_file.name, output_file.name,
"--model", "--model",
"Qwen/Qwen3-0.6B", REASONING_MODEL_NAME,
"--reasoning-parser", "--reasoning-parser",
"qwen3", "qwen3",
], ],
...@@ -278,7 +580,7 @@ def test_transcription(): ...@@ -278,7 +580,7 @@ def test_transcription():
"-o", "-o",
output_file.name, output_file.name,
"--model", "--model",
"openai/whisper-large-v3", SPEECH_LARGE_MODEL_NAME,
], ],
) )
proc.communicate() proc.communicate()
...@@ -316,7 +618,7 @@ def test_transcription_http_url(): ...@@ -316,7 +618,7 @@ def test_transcription_http_url():
"-o", "-o",
output_file.name, output_file.name,
"--model", "--model",
"openai/whisper-large-v3", SPEECH_LARGE_MODEL_NAME,
], ],
) )
proc.communicate() proc.communicate()
...@@ -356,7 +658,7 @@ def test_translation(): ...@@ -356,7 +658,7 @@ def test_translation():
"-o", "-o",
output_file.name, output_file.name,
"--model", "--model",
"openai/whisper-small", SPEECH_SMALL_MODEL_NAME,
], ],
) )
proc.communicate() proc.communicate()
...@@ -378,3 +680,69 @@ def test_translation(): ...@@ -378,3 +680,69 @@ def test_translation():
translation_text = response_body["text"] translation_text = response_body["text"]
translation_text_lower = str(translation_text).strip().lower() translation_text_lower = str(translation_text).strip().lower()
assert "mary" in translation_text_lower or "lamb" in translation_text_lower assert "mary" in translation_text_lower or "lamb" in translation_text_lower
def test_tool_calling():
    """
    Test that tool calling works correctly in run_batch.

    Runs ``vllm run-batch`` on INPUT_TOOL_CALLING_BATCH with
    ``--enable-auto-tool-choice`` and the ``hermes`` tool-call parser, then
    verifies that every output line validates as a BatchRequestOutput, carries
    no error, and contains well-formed tool calls to ``get_current_weather``.
    """
    with (
        tempfile.NamedTemporaryFile("w") as input_file,
        tempfile.NamedTemporaryFile("r") as output_file,
    ):
        input_file.write(INPUT_TOOL_CALLING_BATCH)
        # Flush so the subprocess sees the batch when it opens the file by name.
        input_file.flush()
        proc = subprocess.Popen(
            [
                "vllm",
                "run-batch",
                "-i",
                input_file.name,
                "-o",
                output_file.name,
                "--model",
                REASONING_MODEL_NAME,
                "--enable-auto-tool-choice",
                "--tool-call-parser",
                "hermes",
            ],
        )
        proc.communicate()
        proc.wait()
        assert proc.returncode == 0, f"{proc=}"

        contents = output_file.read()
        # Skip empty lines.
        lines = [line for line in contents.strip().split("\n") if line.strip()]
        # Without this check the loop below would pass vacuously if run-batch
        # exited successfully but wrote nothing to the output file.
        assert lines, "run-batch produced no output lines"

        for line in lines:
            # Ensure that the output format conforms to the openai api.
            # Validation should throw if the schema is wrong.
            BatchRequestOutput.model_validate_json(line)

            # Ensure that there is no error in the response.
            line_dict = json.loads(line)
            assert isinstance(line_dict, dict)
            assert line_dict["error"] is None

            # Check that tool_calls are present in the response.
            response_body = line_dict["response"]["body"]
            assert response_body is not None
            message = response_body["choices"][0]["message"]
            assert "tool_calls" in message
            tool_calls = message["tool_calls"]

            # The request sets tool_choice="required", so the test expects
            # tool_calls to be present and non-empty.
            assert tool_calls is not None
            assert isinstance(tool_calls, list)
            assert len(tool_calls) > 0

            # Verify tool_calls have the expected structure.
            for tool_call in tool_calls:
                assert "id" in tool_call
                assert "type" in tool_call
                assert tool_call["type"] == "function"
                assert "function" in tool_call
                assert "name" in tool_call["function"]
                assert "arguments" in tool_call["function"]
                # Verify the tool name matches our tool definition.
                assert tool_call["function"]["name"] == "get_current_weather"
...@@ -67,38 +67,14 @@ class LoRAParserAction(argparse.Action): ...@@ -67,38 +67,14 @@ class LoRAParserAction(argparse.Action):
@config @config
class FrontendArgs: class BaseFrontendArgs:
"""Arguments for the OpenAI-compatible frontend server.""" """Base arguments for the OpenAI-compatible frontend server.
This base class does not include host, port, and server-specific arguments
like SSL, CORS, and HTTP server settings. Those arguments are added by
the subclasses.
"""
host: str | None = None
"""Host name."""
port: int = 8000
"""Port number."""
uds: str | None = None
"""Unix domain socket path. If set, host and port arguments are ignored."""
uvicorn_log_level: Literal[
"critical", "error", "warning", "info", "debug", "trace"
] = "info"
"""Log level for uvicorn."""
disable_uvicorn_access_log: bool = False
"""Disable uvicorn access log."""
disable_access_log_for_endpoints: str | None = None
"""Comma-separated list of endpoint paths to exclude from uvicorn access
logs. This is useful to reduce log noise from high-frequency endpoints
like health checks. Example: "/health,/metrics,/ping".
When set, access logs for requests to these paths will be suppressed
while keeping logs for other endpoints."""
allow_credentials: bool = False
"""Allow credentials."""
allowed_origins: list[str] = field(default_factory=lambda: ["*"])
"""Allowed origins."""
allowed_methods: list[str] = field(default_factory=lambda: ["*"])
"""Allowed methods."""
allowed_headers: list[str] = field(default_factory=lambda: ["*"])
"""Allowed headers."""
api_key: list[str] | None = None
"""If provided, the server will require one of these keys to be presented in
the header."""
lora_modules: list[LoRAModulePath] | None = None lora_modules: list[LoRAModulePath] | None = None
"""LoRA modules configurations in either 'name=path' format or JSON format """LoRA modules configurations in either 'name=path' format or JSON format
or JSON list format. Example (old format): `'name=path'` Example (new or JSON list format. Example (old format): `'name=path'` Example (new
...@@ -125,27 +101,6 @@ class FrontendArgs: ...@@ -125,27 +101,6 @@ class FrontendArgs:
to disable thinking mode by default for Qwen3/DeepSeek models.""" to disable thinking mode by default for Qwen3/DeepSeek models."""
response_role: str = "assistant" response_role: str = "assistant"
"""The role name to return if `request.add_generation_prompt=true`.""" """The role name to return if `request.add_generation_prompt=true`."""
ssl_keyfile: str | None = None
"""The file path to the SSL key file."""
ssl_certfile: str | None = None
"""The file path to the SSL cert file."""
ssl_ca_certs: str | None = None
"""The CA certificates file."""
enable_ssl_refresh: bool = False
"""Refresh SSL Context when SSL certificate files change"""
ssl_cert_reqs: int = int(ssl.CERT_NONE)
"""Whether client certificate is required (see stdlib ssl module's)."""
ssl_ciphers: str | None = None
"""SSL cipher suites for HTTPS (TLS 1.2 and below only).
Example: 'ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-CHACHA20-POLY1305'"""
root_path: str | None = None
"""FastAPI root_path when app is behind a path based routing proxy."""
middleware: list[str] = field(default_factory=lambda: [])
"""Additional ASGI middleware to apply to the app. We accept multiple
--middleware arguments. The value should be an import path. If a function
is provided, vLLM will add it to the server using
`@app.middleware('http')`. If a class is provided, vLLM will
add it to the server using `app.add_middleware()`."""
return_tokens_as_token_ids: bool = False return_tokens_as_token_ids: bool = False
"""When `--max-logprobs` is specified, represents single tokens as """When `--max-logprobs` is specified, represents single tokens as
strings of the form 'token_id:{token_id}' so that tokens that are not strings of the form 'token_id:{token_id}' so that tokens that are not
...@@ -153,8 +108,6 @@ class FrontendArgs: ...@@ -153,8 +108,6 @@ class FrontendArgs:
disable_frontend_multiprocessing: bool = False disable_frontend_multiprocessing: bool = False
"""If specified, will run the OpenAI frontend server in the same process as """If specified, will run the OpenAI frontend server in the same process as
the model serving engine.""" the model serving engine."""
enable_request_id_headers: bool = False
"""If specified, API server will add X-Request-Id header to responses."""
enable_auto_tool_choice: bool = False enable_auto_tool_choice: bool = False
"""Enable auto tool choice for supported models. Use `--tool-call-parser` """Enable auto tool choice for supported models. Use `--tool-call-parser`
to specify which parser to use.""" to specify which parser to use."""
...@@ -179,8 +132,6 @@ class FrontendArgs: ...@@ -179,8 +132,6 @@ class FrontendArgs:
max_log_len: int | None = None max_log_len: int | None = None
"""Max number of prompt characters or prompt ID numbers being printed in """Max number of prompt characters or prompt ID numbers being printed in
log. The default of None means unlimited.""" log. The default of None means unlimited."""
disable_fastapi_docs: bool = False
"""Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint."""
enable_prompt_tokens_details: bool = False enable_prompt_tokens_details: bool = False
"""If set to True, enable prompt_tokens_details in usage.""" """If set to True, enable prompt_tokens_details in usage."""
enable_server_load_tracking: bool = False enable_server_load_tracking: bool = False
...@@ -197,12 +148,6 @@ class FrontendArgs: ...@@ -197,12 +148,6 @@ class FrontendArgs:
"""If set to False, output deltas will not be logged. Relevant only if """If set to False, output deltas will not be logged. Relevant only if
--enable-log-outputs is set. --enable-log-outputs is set.
""" """
h11_max_incomplete_event_size: int = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT
"""Maximum size (bytes) of an incomplete HTTP event (header or body) for
h11 parser. Helps mitigate header abuse. Default: 4194304 (4 MB)."""
h11_max_header_count: int = H11_MAX_HEADER_COUNT_DEFAULT
"""Maximum number of HTTP headers allowed in a request for h11 parser.
Helps mitigate header abuse. Default: 256."""
log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE
"""If set to True, log the stack trace of error responses""" """If set to True, log the stack trace of error responses"""
tokens_only: bool = False tokens_only: bool = False
...@@ -210,17 +155,135 @@ class FrontendArgs: ...@@ -210,17 +155,135 @@ class FrontendArgs:
If set to True, only enable the Tokens In<>Out endpoint. If set to True, only enable the Tokens In<>Out endpoint.
This is intended for use in a Disaggregated Everything setup. This is intended for use in a Disaggregated Everything setup.
""" """
@classmethod
def _customize_cli_kwargs(
    cls,
    frontend_kwargs: dict[str, Any],
) -> dict[str, Any]:
    """Adjust the argparse kwargs of selected fields before registration.

    The mapping is modified in place and also returned. Subclasses should
    override this and call
    ``super()._customize_cli_kwargs(frontend_kwargs)`` first.
    """
    # default_chat_template_kwargs is parsed from the command line with
    # json.loads.
    frontend_kwargs["default_chat_template_kwargs"]["type"] = json.loads

    # LoRA modules use a dedicated argparse action together with
    # optional_type(str).
    frontend_kwargs["lora_modules"].update(
        type=optional_type(str),
        action=LoRAParserAction,
    )

    # Show the registered tool parsers as the metavar of --tool-call-parser.
    registered = ",".join(ToolParserManager.list_registered())
    frontend_kwargs["tool_call_parser"]["metavar"] = (
        f"{{{registered}}} or name registered in --tool-parser-plugin"
    )
    return frontend_kwargs
@classmethod
def add_cli_args(cls, parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    """Add one CLI option per field of this class to ``parser``.

    The options are placed in an argument group titled after the class name
    with "Args" removed and described by the class docstring. Subclasses
    should override ``_customize_cli_kwargs`` instead of this method so that
    base-class postprocessing is always applied.

    Returns the same ``parser`` that was passed in.
    """
    from vllm.engine.arg_utils import get_kwargs

    cli_kwargs = cls._customize_cli_kwargs(get_kwargs(cls))

    group = parser.add_argument_group(
        title=cls.__name__.replace("Args", ""),
        description=cls.__doc__,
    )
    for field_name, arg_kwargs in cli_kwargs.items():
        # "flags" holds extra option strings; it is removed from the kwargs
        # so that it is not forwarded to add_argument as a keyword.
        flags = arg_kwargs.pop("flags", [])
        option = f"--{field_name.replace('_', '-')}"
        group.add_argument(*flags, option, **arg_kwargs)
    return parser
@config
class FrontendArgs(BaseFrontendArgs):
"""Arguments for the OpenAI-compatible frontend server."""
host: str | None = None
"""Host name."""
port: int = 8000
"""Port number."""
uds: str | None = None
"""Unix domain socket path. If set, host and port arguments are ignored."""
uvicorn_log_level: Literal[
"critical", "error", "warning", "info", "debug", "trace"
] = "info"
"""Log level for uvicorn."""
disable_uvicorn_access_log: bool = False
"""Disable uvicorn access log."""
disable_access_log_for_endpoints: str | None = None
"""Comma-separated list of endpoint paths to exclude from uvicorn access
logs. This is useful to reduce log noise from high-frequency endpoints
like health checks. Example: "/health,/metrics,/ping".
When set, access logs for requests to these paths will be suppressed
while keeping logs for other endpoints."""
allow_credentials: bool = False
"""Allow credentials."""
allowed_origins: list[str] = field(default_factory=lambda: ["*"])
"""Allowed origins."""
allowed_methods: list[str] = field(default_factory=lambda: ["*"])
"""Allowed methods."""
allowed_headers: list[str] = field(default_factory=lambda: ["*"])
"""Allowed headers."""
api_key: list[str] | None = None
"""If provided, the server will require one of these keys to be presented in
the header."""
ssl_keyfile: str | None = None
"""The file path to the SSL key file."""
ssl_certfile: str | None = None
"""The file path to the SSL cert file."""
ssl_ca_certs: str | None = None
"""The CA certificates file."""
enable_ssl_refresh: bool = False
"""Refresh SSL Context when SSL certificate files change"""
ssl_cert_reqs: int = int(ssl.CERT_NONE)
"""Whether client certificate is required (see stdlib ssl module's)."""
ssl_ciphers: str | None = None
"""SSL cipher suites for HTTPS (TLS 1.2 and below only).
Example: 'ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-CHACHA20-POLY1305'"""
root_path: str | None = None
"""FastAPI root_path when app is behind a path based routing proxy."""
middleware: list[str] = field(default_factory=lambda: [])
"""Additional ASGI middleware to apply to the app. We accept multiple
--middleware arguments. The value should be an import path. If a function
is provided, vLLM will add it to the server using
`@app.middleware('http')`. If a class is provided, vLLM will
add it to the server using `app.add_middleware()`."""
enable_request_id_headers: bool = False
"""If specified, API server will add X-Request-Id header to responses."""
disable_fastapi_docs: bool = False
"""Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint."""
h11_max_incomplete_event_size: int = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT
"""Maximum size (bytes) of an incomplete HTTP event (header or body) for
h11 parser. Helps mitigate header abuse. Default: 4194304 (4 MB)."""
h11_max_header_count: int = H11_MAX_HEADER_COUNT_DEFAULT
"""Maximum number of HTTP headers allowed in a request for h11 parser.
Helps mitigate header abuse. Default: 256."""
enable_offline_docs: bool = False enable_offline_docs: bool = False
""" """
Enable offline FastAPI documentation for air-gapped environments. Enable offline FastAPI documentation for air-gapped environments.
Uses vendored static assets bundled with vLLM. Uses vendored static assets bundled with vLLM.
""" """
@staticmethod @classmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: def _customize_cli_kwargs(
from vllm.engine.arg_utils import get_kwargs cls,
frontend_kwargs: dict[str, Any],
frontend_kwargs = get_kwargs(FrontendArgs) ) -> dict[str, Any]:
frontend_kwargs = super()._customize_cli_kwargs(frontend_kwargs)
# Special case: allowed_origins, allowed_methods, allowed_headers all # Special case: allowed_origins, allowed_methods, allowed_headers all
# need json.loads type # need json.loads type
...@@ -232,14 +295,6 @@ class FrontendArgs: ...@@ -232,14 +295,6 @@ class FrontendArgs:
del frontend_kwargs["allowed_methods"]["nargs"] del frontend_kwargs["allowed_methods"]["nargs"]
del frontend_kwargs["allowed_headers"]["nargs"] del frontend_kwargs["allowed_headers"]["nargs"]
# Special case: default_chat_template_kwargs needs json.loads type
frontend_kwargs["default_chat_template_kwargs"]["type"] = json.loads
# Special case: LoRA modules need custom parser action and
# optional_type(str)
frontend_kwargs["lora_modules"]["type"] = optional_type(str)
frontend_kwargs["lora_modules"]["action"] = LoRAParserAction
# Special case: Middleware needs to append action # Special case: Middleware needs to append action
frontend_kwargs["middleware"]["action"] = "append" frontend_kwargs["middleware"]["action"] = "append"
frontend_kwargs["middleware"]["type"] = str frontend_kwargs["middleware"]["type"] = str
...@@ -252,22 +307,7 @@ class FrontendArgs: ...@@ -252,22 +307,7 @@ class FrontendArgs:
if "nargs" in frontend_kwargs["disable_access_log_for_endpoints"]: if "nargs" in frontend_kwargs["disable_access_log_for_endpoints"]:
del frontend_kwargs["disable_access_log_for_endpoints"]["nargs"] del frontend_kwargs["disable_access_log_for_endpoints"]["nargs"]
# Special case: Tool call parser shows built-in options. return frontend_kwargs
valid_tool_parsers = list(ToolParserManager.list_registered())
parsers_str = ",".join(valid_tool_parsers)
frontend_kwargs["tool_call_parser"]["metavar"] = (
f"{{{parsers_str}}} or name registered in --tool-parser-plugin"
)
frontend_group = parser.add_argument_group(
title="Frontend",
description=FrontendArgs.__doc__,
)
for key, value in frontend_kwargs.items():
frontend_group.add_argument(f"--{key.replace('_', '-')}", **value)
return parser
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import asyncio import asyncio
import base64 import base64
import sys
import tempfile import tempfile
from argparse import Namespace from argparse import Namespace
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
...@@ -17,23 +18,23 @@ from fastapi import UploadFile ...@@ -17,23 +18,23 @@ from fastapi import UploadFile
from prometheus_client import start_http_server from prometheus_client import start_http_server
from pydantic import Field, TypeAdapter, field_validator, model_validator from pydantic import Field, TypeAdapter, field_validator, model_validator
from pydantic_core.core_schema import ValidationInfo from pydantic_core.core_schema import ValidationInfo
from starlette.datastructures import State
from tqdm import tqdm from tqdm import tqdm
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.config import config
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.api_server import init_app_state
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
) )
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat from vllm.entrypoints.openai.cli_args import BaseFrontendArgs
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
ErrorInfo, ErrorInfo,
ErrorResponse, ErrorResponse,
OpenAIBaseModel, OpenAIBaseModel,
) )
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.speech_to_text.protocol import ( from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionRequest, TranscriptionRequest,
TranscriptionResponse, TranscriptionResponse,
...@@ -42,25 +43,18 @@ from vllm.entrypoints.openai.speech_to_text.protocol import ( ...@@ -42,25 +43,18 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
TranslationResponse, TranslationResponse,
TranslationResponseVerbose, TranslationResponseVerbose,
) )
from vllm.entrypoints.openai.speech_to_text.serving import (
OpenAIServingTranscription,
OpenAIServingTranslation,
)
from vllm.entrypoints.pooling.embed.protocol import ( from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingRequest, EmbeddingRequest,
EmbeddingResponse, EmbeddingResponse,
) )
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.pooling.score.protocol import ( from vllm.entrypoints.pooling.score.protocol import (
RerankRequest, RerankRequest,
RerankResponse, RerankResponse,
ScoreRequest, ScoreRequest,
ScoreResponse, ScoreResponse,
) )
from vllm.entrypoints.pooling.score.serving import ServingScores
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager from vllm.reasoning import ReasoningParserManager
from vllm.tasks import SupportedTask
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
...@@ -219,87 +213,73 @@ class BatchRequestOutput(OpenAIBaseModel): ...@@ -219,87 +213,73 @@ class BatchRequestOutput(OpenAIBaseModel):
error: Any | None error: Any | None
# CLI/config fields for the batch runner. Extends BaseFrontendArgs with the
# batch-only input/output file options and the Prometheus metrics options.
# NOTE(review): the bare string after each field is an attribute docstring;
# presumably the @config / add_cli_args machinery turns it into the --help
# text, so treat those strings as user-facing -- confirm before rewording.
@config
class BatchFrontendArgs(BaseFrontendArgs):
    """Arguments for the batch runner frontend."""

    input_file: str | None = None
    """The path or url to a single input file. Currently supports local file
    paths, or the http protocol (http or https). If a URL is specified,
    the file should be available via HTTP GET."""
    output_file: str | None = None
    """The path or url to a single output file. Currently supports
    local file paths, or web (http or https) urls. If a URL is specified,
    the file should be available via HTTP PUT."""
    output_tmp_dir: str | None = None
    """The directory to store the output file before uploading it
    to the output URL."""
    enable_metrics: bool = False
    """Enable Prometheus metrics"""
    # `host` has no default address of its own (None); only the deprecated
    # `url` field below still carries the "0.0.0.0" default.
    host: str | None = None
    """Host name for the Prometheus metrics server
    (only needed if enable-metrics is set)."""
    port: int = 8000
    """Port number for the Prometheus metrics server
    (only needed if enable-metrics is set)."""
    url: str = "0.0.0.0"
    """[DEPRECATED] Host name for the Prometheus metrics server
    (only needed if enable-metrics is set). Use --host instead."""

    @classmethod
    def _customize_cli_kwargs(
        cls,
        frontend_kwargs: dict[str, Any],
    ) -> dict[str, Any]:
        """Adjust the per-field CLI kwargs for the batch-specific options.

        Args:
            frontend_kwargs: Mapping of field name to the keyword arguments
                used when registering that field as a CLI option.

        Returns:
            The same mapping after the parent class's customizations and the
            batch-specific overrides below have been applied.
        """
        # Start from whatever the parent class already customized.
        frontend_kwargs = super()._customize_cli_kwargs(frontend_kwargs)

        # Input and output files are mandatory and get the short aliases
        # -i / -o.
        # NOTE(review): "flags" is presumably read by the inherited
        # add_cli_args to register extra option strings -- confirm.
        frontend_kwargs["input_file"]["flags"] = ["-i"]
        frontend_kwargs["input_file"]["required"] = True
        frontend_kwargs["output_file"]["flags"] = ["-o"]
        frontend_kwargs["output_file"]["required"] = True

        # --enable-metrics is a plain on/off switch.
        frontend_kwargs["enable_metrics"]["action"] = "store_true"

        # Mark --url as deprecated; --host is its replacement.
        frontend_kwargs["url"]["deprecated"] = True

        return frontend_kwargs


def make_arg_parser(parser: FlexibleArgumentParser):
    """Register the batch frontend options, then the AsyncEngineArgs options.

    Args:
        parser: The parser to extend.

    Returns:
        The parser returned by the last ``add_cli_args`` call.
    """
    parser = BatchFrontendArgs.add_cli_args(parser)
    parser = AsyncEngineArgs.add_cli_args(parser)
    return parser
def _cli_flag_given(flag: str, argv: list[str]) -> bool:
    """Return True if *flag* occurs in *argv* as ``flag`` or ``flag=value``."""
    return any(arg == flag or arg.startswith(f"{flag}=") for arg in argv)


def parse_args():
    """Parse the batch runner's command line and resolve the metrics host.

    ``--url`` is the deprecated spelling of ``--host``. The resolved address
    is always stored in ``args.host``:

    * ``--host`` given: it wins, and ``--url`` (if any) is ignored.
    * only ``--url`` given: its value is copied to ``args.host`` and a
      deprecation warning is logged.
    * neither given: ``args.host`` falls back to the default of ``url``
      so the metrics server keeps a concrete bind address instead of
      receiving ``None``.

    Returns:
        The parsed argument namespace.
    """
    parser = FlexibleArgumentParser(description="vLLM OpenAI-Compatible batch runner.")
    args = make_arg_parser(parser).parse_args()

    # Explicitness cannot be recovered from the parsed values (a typed value
    # may equal the default), so look at the raw command line. The program
    # name in argv[0] is not an option and is skipped.
    argv = sys.argv[1:]
    if _cli_flag_given("--host", argv):
        return args

    if _cli_flag_given("--url", argv):
        logger.warning_once(
            "Using --url for metrics is deprecated. Please use --host instead."
        )

    # Backward compatibility: without an explicit --host, the (typed or
    # default) --url value is the metrics bind address.
    url = getattr(args, "url", None)
    if url is not None:
        args.host = url
    return args
# explicitly use pure text format, with a newline at the end # explicitly use pure text format, with a newline at the end
...@@ -671,12 +651,9 @@ def make_transcription_wrapper(is_translation: bool) -> WrapperFn: ...@@ -671,12 +651,9 @@ def make_transcription_wrapper(is_translation: bool) -> WrapperFn:
return wrapper return wrapper
def build_endpoint_registry( async def build_endpoint_registry(
engine_client: EngineClient, engine_client: EngineClient,
args: Namespace, args: Namespace,
base_model_paths: list[BaseModelPath],
request_logger: RequestLogger | None,
supported_tasks: tuple[SupportedTask, ...],
) -> dict[str, dict[str, Any]]: ) -> dict[str, dict[str, Any]]:
""" """
Build the endpoint registry with all serving objects and handler configurations. Build the endpoint registry with all serving objects and handler configurations.
...@@ -684,90 +661,27 @@ def build_endpoint_registry( ...@@ -684,90 +661,27 @@ def build_endpoint_registry(
Args: Args:
engine_client: The engine client engine_client: The engine client
args: Command line arguments args: Command line arguments
base_model_paths: List of base model paths
request_logger: Optional request logger
supported_tasks: Tuple of supported tasks
Returns: Returns:
Dictionary mapping endpoint keys to their configurations Dictionary mapping endpoint keys to their configurations
""" """
model_config = engine_client.model_config supported_tasks = await engine_client.get_supported_tasks()
logger.info("Supported tasks: %s", supported_tasks)
# Create the openai serving objects.
openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
base_model_paths=base_model_paths,
lora_modules=None,
)
openai_serving_chat = (
OpenAIServingChat(
engine_client,
openai_serving_models,
args.response_role,
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
default_chat_template_kwargs=getattr(
args, "default_chat_template_kwargs", None
),
)
if "generate" in supported_tasks
else None
)
openai_serving_embedding = ( # Create a state object to hold serving objects
OpenAIServingEmbedding( state = State()
engine_client,
openai_serving_models,
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
)
if "embed" in supported_tasks
else None
)
enable_serving_reranking = ( # Initialize all serving objects using init_app_state
"classify" in supported_tasks # This provides full functionality including chat template processing,
and getattr(model_config.hf_config, "num_labels", 0) == 1 # LoRA support, tool servers, etc.
) await init_app_state(engine_client, state, args, supported_tasks)
openai_serving_scores = ( # Get serving objects from state (defaulting to None if not set)
ServingScores( openai_serving_chat = getattr(state, "openai_serving_chat", None)
engine_client, openai_serving_embedding = getattr(state, "openai_serving_embedding", None)
openai_serving_models, openai_serving_scores = getattr(state, "openai_serving_scores", None)
request_logger=request_logger, openai_serving_transcription = getattr(state, "openai_serving_transcription", None)
score_template=None, openai_serving_translation = getattr(state, "openai_serving_translation", None)
)
if ("embed" in supported_tasks or enable_serving_reranking)
else None
)
openai_serving_transcription = (
OpenAIServingTranscription(
engine_client,
openai_serving_models,
request_logger=request_logger,
enable_force_include_usage=args.enable_force_include_usage,
)
if "transcription" in supported_tasks
else None
)
openai_serving_translation = (
OpenAIServingTranslation(
engine_client,
openai_serving_models,
request_logger=request_logger,
enable_force_include_usage=args.enable_force_include_usage,
)
if "transcription" in supported_tasks
else None
)
# Registry of endpoint configurations # Registry of endpoint configurations
endpoint_registry: dict[str, dict[str, Any]] = { endpoint_registry: dict[str, dict[str, Any]] = {
...@@ -845,29 +759,9 @@ async def run_batch( ...@@ -845,29 +759,9 @@ async def run_batch(
engine_client: EngineClient, engine_client: EngineClient,
args: Namespace, args: Namespace,
) -> None: ) -> None:
if args.served_model_name is not None: endpoint_registry = await build_endpoint_registry(
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
if args.enable_log_requests:
request_logger = RequestLogger(max_log_len=args.max_log_len)
else:
request_logger = None
base_model_paths = [
BaseModelPath(name=name, model_path=args.model) for name in served_model_names
]
supported_tasks = await engine_client.get_supported_tasks()
logger.info("Supported tasks: %s", supported_tasks)
endpoint_registry = build_endpoint_registry(
engine_client=engine_client, engine_client=engine_client,
args=args, args=args,
base_model_paths=base_model_paths,
request_logger=request_logger,
supported_tasks=supported_tasks,
) )
tracker = BatchProgressTracker() tracker = BatchProgressTracker()
...@@ -942,7 +836,7 @@ if __name__ == "__main__": ...@@ -942,7 +836,7 @@ if __name__ == "__main__":
# to publish metrics at the /metrics endpoint. # to publish metrics at the /metrics endpoint.
if args.enable_metrics: if args.enable_metrics:
logger.info("Prometheus metrics enabled") logger.info("Prometheus metrics enabled")
start_http_server(port=args.port, addr=args.url) start_http_server(port=args.port, addr=args.host)
else: else:
logger.info("Prometheus metrics disabled") logger.info("Prometheus metrics disabled")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment