Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
73deea2f
"docs/source/api/params.md" did not exist on "aba8d6ee006b78149ac4514f460e4038b2d4f607"
Unverified
Commit
73deea2f
authored
Mar 14, 2025
by
daniel-salib
Committed by
GitHub
Mar 14, 2025
Browse files
[Frontend] track server_load (#13950)
parent
9d2b4a70
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
131 additions
and
4 deletions
+131
-4
tests/entrypoints/openai/test_basic.py
tests/entrypoints/openai/test_basic.py
+48
-0
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+30
-2
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+7
-0
vllm/entrypoints/utils.py
vllm/entrypoints/utils.py
+46
-2
No files found.
tests/entrypoints/openai/test_basic.py
View file @
73deea2f
...
@@ -171,3 +171,51 @@ async def test_request_wrong_content_type(server: RemoteOpenAIServer):
...
@@ -171,3 +171,51 @@ async def test_request_wrong_content_type(server: RemoteOpenAIServer):
extra_headers
=
{
extra_headers
=
{
"Content-Type"
:
"application/x-www-form-urlencoded"
"Content-Type"
:
"application/x-www-form-urlencoded"
})
})
@
pytest
.
mark
.
parametrize
(
"server_args"
,
[
pytest
.
param
([
"--enable-server-load-tracking"
],
id
=
"enable-server-load-tracking"
)
],
indirect
=
True
,
)
@
pytest
.
mark
.
asyncio
async
def
test_server_load
(
server
:
RemoteOpenAIServer
):
# Check initial server load
response
=
requests
.
get
(
server
.
url_for
(
"load"
))
assert
response
.
status_code
==
HTTPStatus
.
OK
assert
response
.
json
().
get
(
"server_load"
)
==
0
def
make_long_completion_request
():
return
requests
.
post
(
server
.
url_for
(
"v1/completions"
),
headers
=
{
"Content-Type"
:
"application/json"
},
json
=
{
"prompt"
:
"Give me a long story"
,
"max_tokens"
:
1000
,
"temperature"
:
0
,
},
)
# Start the completion request in a background thread.
completion_future
=
asyncio
.
create_task
(
asyncio
.
to_thread
(
make_long_completion_request
))
# Give a short delay to ensure the request has started.
await
asyncio
.
sleep
(
0.1
)
# Check server load while the completion request is running.
response
=
requests
.
get
(
server
.
url_for
(
"load"
))
assert
response
.
status_code
==
HTTPStatus
.
OK
assert
response
.
json
().
get
(
"server_load"
)
==
1
# Wait for the completion request to finish.
await
completion_future
await
asyncio
.
sleep
(
0.1
)
# Check server load after the completion request has finished.
response
=
requests
.
get
(
server
.
url_for
(
"load"
))
assert
response
.
status_code
==
HTTPStatus
.
OK
assert
response
.
json
().
get
(
"server_load"
)
==
0
vllm/entrypoints/openai/api_server.py
View file @
73deea2f
...
@@ -80,7 +80,7 @@ from vllm.entrypoints.openai.serving_tokenization import (
...
@@ -80,7 +80,7 @@ from vllm.entrypoints.openai.serving_tokenization import (
from
vllm.entrypoints.openai.serving_transcription
import
(
from
vllm.entrypoints.openai.serving_transcription
import
(
OpenAIServingTranscription
)
OpenAIServingTranscription
)
from
vllm.entrypoints.openai.tool_parsers
import
ToolParserManager
from
vllm.entrypoints.openai.tool_parsers
import
ToolParserManager
from
vllm.entrypoints.utils
import
with_cancellation
from
vllm.entrypoints.utils
import
load_aware_call
,
with_cancellation
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
FlexibleArgumentParser
,
get_open_zmq_ipc_path
,
from
vllm.utils
import
(
FlexibleArgumentParser
,
get_open_zmq_ipc_path
,
...
@@ -347,6 +347,24 @@ async def health(raw_request: Request) -> Response:
...
@@ -347,6 +347,24 @@ async def health(raw_request: Request) -> Response:
return
Response
(
status_code
=
200
)
return
Response
(
status_code
=
200
)
@
router
.
get
(
"/load"
)
async
def
get_server_load_metrics
(
request
:
Request
):
# This endpoint returns the current server load metrics.
# It tracks requests utilizing the GPU from the following routes:
# - /v1/chat/completions
# - /v1/completions
# - /v1/audio/transcriptions
# - /v1/embeddings
# - /pooling
# - /score
# - /v1/score
# - /rerank
# - /v1/rerank
# - /v2/rerank
return
JSONResponse
(
content
=
{
'server_load'
:
request
.
app
.
state
.
server_load_metrics
})
@
router
.
api_route
(
"/ping"
,
methods
=
[
"GET"
,
"POST"
])
@
router
.
api_route
(
"/ping"
,
methods
=
[
"GET"
,
"POST"
])
async
def
ping
(
raw_request
:
Request
)
->
Response
:
async
def
ping
(
raw_request
:
Request
)
->
Response
:
"""Ping check. Endpoint required for SageMaker"""
"""Ping check. Endpoint required for SageMaker"""
...
@@ -400,6 +418,7 @@ async def show_version():
...
@@ -400,6 +418,7 @@ async def show_version():
@
router
.
post
(
"/v1/chat/completions"
,
@
router
.
post
(
"/v1/chat/completions"
,
dependencies
=
[
Depends
(
validate_json_request
)])
dependencies
=
[
Depends
(
validate_json_request
)])
@
with_cancellation
@
with_cancellation
@
load_aware_call
async
def
create_chat_completion
(
request
:
ChatCompletionRequest
,
async
def
create_chat_completion
(
request
:
ChatCompletionRequest
,
raw_request
:
Request
):
raw_request
:
Request
):
handler
=
chat
(
raw_request
)
handler
=
chat
(
raw_request
)
...
@@ -421,6 +440,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
...
@@ -421,6 +440,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
@
router
.
post
(
"/v1/completions"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
router
.
post
(
"/v1/completions"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
with_cancellation
@
with_cancellation
@
load_aware_call
async
def
create_completion
(
request
:
CompletionRequest
,
raw_request
:
Request
):
async
def
create_completion
(
request
:
CompletionRequest
,
raw_request
:
Request
):
handler
=
completion
(
raw_request
)
handler
=
completion
(
raw_request
)
if
handler
is
None
:
if
handler
is
None
:
...
@@ -439,6 +459,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
...
@@ -439,6 +459,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
@
router
.
post
(
"/v1/embeddings"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
router
.
post
(
"/v1/embeddings"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
with_cancellation
@
with_cancellation
@
load_aware_call
async
def
create_embedding
(
request
:
EmbeddingRequest
,
raw_request
:
Request
):
async
def
create_embedding
(
request
:
EmbeddingRequest
,
raw_request
:
Request
):
handler
=
embedding
(
raw_request
)
handler
=
embedding
(
raw_request
)
if
handler
is
None
:
if
handler
is
None
:
...
@@ -485,6 +506,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
...
@@ -485,6 +506,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
@
router
.
post
(
"/pooling"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
router
.
post
(
"/pooling"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
with_cancellation
@
with_cancellation
@
load_aware_call
async
def
create_pooling
(
request
:
PoolingRequest
,
raw_request
:
Request
):
async
def
create_pooling
(
request
:
PoolingRequest
,
raw_request
:
Request
):
handler
=
pooling
(
raw_request
)
handler
=
pooling
(
raw_request
)
if
handler
is
None
:
if
handler
is
None
:
...
@@ -503,6 +525,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
...
@@ -503,6 +525,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
@
router
.
post
(
"/score"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
router
.
post
(
"/score"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
with_cancellation
@
with_cancellation
@
load_aware_call
async
def
create_score
(
request
:
ScoreRequest
,
raw_request
:
Request
):
async
def
create_score
(
request
:
ScoreRequest
,
raw_request
:
Request
):
handler
=
score
(
raw_request
)
handler
=
score
(
raw_request
)
if
handler
is
None
:
if
handler
is
None
:
...
@@ -521,6 +544,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
...
@@ -521,6 +544,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
@
router
.
post
(
"/v1/score"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
router
.
post
(
"/v1/score"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
with_cancellation
@
with_cancellation
@
load_aware_call
async
def
create_score_v1
(
request
:
ScoreRequest
,
raw_request
:
Request
):
async
def
create_score_v1
(
request
:
ScoreRequest
,
raw_request
:
Request
):
logger
.
warning
(
logger
.
warning
(
"To indicate that Score API is not part of standard OpenAI API, we "
"To indicate that Score API is not part of standard OpenAI API, we "
...
@@ -531,10 +555,10 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
...
@@ -531,10 +555,10 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
@
router
.
post
(
"/v1/audio/transcriptions"
)
@
router
.
post
(
"/v1/audio/transcriptions"
)
@
with_cancellation
@
with_cancellation
@
load_aware_call
async
def
create_transcriptions
(
request
:
Annotated
[
TranscriptionRequest
,
async
def
create_transcriptions
(
request
:
Annotated
[
TranscriptionRequest
,
Form
()],
Form
()],
raw_request
:
Request
):
raw_request
:
Request
):
handler
=
transcription
(
raw_request
)
handler
=
transcription
(
raw_request
)
if
handler
is
None
:
if
handler
is
None
:
return
base
(
raw_request
).
create_error_response
(
return
base
(
raw_request
).
create_error_response
(
...
@@ -556,6 +580,7 @@ async def create_transcriptions(request: Annotated[TranscriptionRequest,
...
@@ -556,6 +580,7 @@ async def create_transcriptions(request: Annotated[TranscriptionRequest,
@
router
.
post
(
"/rerank"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
router
.
post
(
"/rerank"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
with_cancellation
@
with_cancellation
@
load_aware_call
async
def
do_rerank
(
request
:
RerankRequest
,
raw_request
:
Request
):
async
def
do_rerank
(
request
:
RerankRequest
,
raw_request
:
Request
):
handler
=
rerank
(
raw_request
)
handler
=
rerank
(
raw_request
)
if
handler
is
None
:
if
handler
is
None
:
...
@@ -894,6 +919,9 @@ async def init_app_state(
...
@@ -894,6 +919,9 @@ async def init_app_state(
)
if
model_config
.
runner_type
==
"transcription"
else
None
)
if
model_config
.
runner_type
==
"transcription"
else
None
state
.
task
=
model_config
.
task
state
.
task
=
model_config
.
task
state
.
enable_server_load_tracking
=
args
.
enable_server_load_tracking
state
.
server_load_metrics
=
0
def
create_server_socket
(
addr
:
tuple
[
str
,
int
])
->
socket
.
socket
:
def
create_server_socket
(
addr
:
tuple
[
str
,
int
])
->
socket
.
socket
:
family
=
socket
.
AF_INET
family
=
socket
.
AF_INET
...
...
vllm/entrypoints/openai/cli_args.py
View file @
73deea2f
...
@@ -257,6 +257,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
...
@@ -257,6 +257,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
action
=
'store_true'
,
action
=
'store_true'
,
default
=
False
,
default
=
False
,
help
=
"If set to True, enable prompt_tokens_details in usage."
)
help
=
"If set to True, enable prompt_tokens_details in usage."
)
parser
.
add_argument
(
"--enable-server-load-tracking"
,
action
=
'store_true'
,
default
=
False
,
help
=
"If set to True, enable tracking server_load_metrics in the app state."
)
return
parser
return
parser
...
...
vllm/entrypoints/utils.py
View file @
73deea2f
...
@@ -4,6 +4,8 @@ import asyncio
...
@@ -4,6 +4,8 @@ import asyncio
import
functools
import
functools
from
fastapi
import
Request
from
fastapi
import
Request
from
fastapi.responses
import
JSONResponse
,
StreamingResponse
from
starlette.background
import
BackgroundTask
,
BackgroundTasks
async
def
listen_for_disconnect
(
request
:
Request
)
->
None
:
async
def
listen_for_disconnect
(
request
:
Request
)
->
None
:
...
@@ -57,3 +59,45 @@ def with_cancellation(handler_func):
...
@@ -57,3 +59,45 @@ def with_cancellation(handler_func):
return
None
return
None
return
wrapper
return
wrapper
def
decrement_server_load
(
request
:
Request
):
request
.
app
.
state
.
server_load_metrics
-=
1
def
load_aware_call
(
func
):
@
functools
.
wraps
(
func
)
async
def
wrapper
(
*
args
,
raw_request
:
Request
,
**
kwargs
):
if
not
raw_request
.
app
.
state
.
enable_server_load_tracking
:
return
await
func
(
*
args
,
raw_request
=
raw_request
,
**
kwargs
)
raw_request
.
app
.
state
.
server_load_metrics
+=
1
try
:
response
=
await
func
(
*
args
,
raw_request
=
raw_request
,
**
kwargs
)
except
Exception
:
raw_request
.
app
.
state
.
server_load_metrics
-=
1
raise
if
isinstance
(
response
,
(
JSONResponse
,
StreamingResponse
)):
if
response
.
background
is
None
:
response
.
background
=
BackgroundTask
(
decrement_server_load
,
raw_request
)
elif
isinstance
(
response
.
background
,
BackgroundTasks
):
response
.
background
.
add_task
(
decrement_server_load
,
raw_request
)
elif
isinstance
(
response
.
background
,
BackgroundTask
):
# Convert the single BackgroundTask to BackgroundTasks
# and chain the decrement_server_load task to it
tasks
=
BackgroundTasks
()
tasks
.
add_task
(
response
.
background
.
func
,
*
response
.
background
.
args
,
**
response
.
background
.
kwargs
)
tasks
.
add_task
(
decrement_server_load
,
raw_request
)
response
.
background
=
tasks
else
:
raw_request
.
app
.
state
.
server_load_metrics
-=
1
return
response
return
wrapper
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment