Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
73deea2f
Unverified
Commit
73deea2f
authored
Mar 14, 2025
by
daniel-salib
Committed by
GitHub
Mar 14, 2025
Browse files
[Frontend] track server_load (#13950)
parent
9d2b4a70
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
131 additions
and
4 deletions
+131
-4
tests/entrypoints/openai/test_basic.py
tests/entrypoints/openai/test_basic.py
+48
-0
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+30
-2
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+7
-0
vllm/entrypoints/utils.py
vllm/entrypoints/utils.py
+46
-2
No files found.
tests/entrypoints/openai/test_basic.py
View file @
73deea2f
...
@@ -171,3 +171,51 @@ async def test_request_wrong_content_type(server: RemoteOpenAIServer):
...
@@ -171,3 +171,51 @@ async def test_request_wrong_content_type(server: RemoteOpenAIServer):
extra_headers
=
{
extra_headers
=
{
"Content-Type"
:
"application/x-www-form-urlencoded"
"Content-Type"
:
"application/x-www-form-urlencoded"
})
})
@pytest.mark.parametrize(
    "server_args",
    [
        pytest.param(["--enable-server-load-tracking"],
                     id="enable-server-load-tracking")
    ],
    indirect=True,
)
@pytest.mark.asyncio
async def test_server_load(server: RemoteOpenAIServer):
    """Check that /load reports 0 -> 1 -> 0 around one long completion.

    The server is started with ``--enable-server-load-tracking``. The
    in-flight and post-completion checks poll ``/load`` with a deadline
    instead of sleeping for a fixed 0.1s, so the test does not depend on how
    quickly the server picks up the request or runs its post-response
    bookkeeping.
    """

    def get_server_load():
        # Every /load query must succeed; return the reported counter.
        response = requests.get(server.url_for("load"))
        assert response.status_code == HTTPStatus.OK
        return response.json().get("server_load")

    async def wait_for_server_load(expected, timeout=5.0, interval=0.05):
        # Poll /load until it reports `expected` or `timeout` seconds pass.
        # The last observed value is returned so the caller's assert shows
        # what the server actually reported on failure.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while True:
            load = get_server_load()
            if load == expected or loop.time() >= deadline:
                return load
            await asyncio.sleep(interval)

    def make_long_completion_request():
        return requests.post(
            server.url_for("v1/completions"),
            headers={"Content-Type": "application/json"},
            json={
                "prompt": "Give me a long story",
                "max_tokens": 1000,
                "temperature": 0,
            },
        )

    # Check initial server load.
    assert get_server_load() == 0

    # Start the completion request in a background thread.
    completion_future = asyncio.create_task(
        asyncio.to_thread(make_long_completion_request))

    # Check server load while the completion request is running.
    assert await wait_for_server_load(1) == 1

    # Wait for the completion request to finish.
    await completion_future

    # Check server load after the completion request has finished.
    assert await wait_for_server_load(0) == 0
vllm/entrypoints/openai/api_server.py
View file @
73deea2f
...
@@ -80,7 +80,7 @@ from vllm.entrypoints.openai.serving_tokenization import (
...
@@ -80,7 +80,7 @@ from vllm.entrypoints.openai.serving_tokenization import (
from
vllm.entrypoints.openai.serving_transcription
import
(
from
vllm.entrypoints.openai.serving_transcription
import
(
OpenAIServingTranscription
)
OpenAIServingTranscription
)
from
vllm.entrypoints.openai.tool_parsers
import
ToolParserManager
from
vllm.entrypoints.openai.tool_parsers
import
ToolParserManager
from
vllm.entrypoints.utils
import
with_cancellation
from
vllm.entrypoints.utils
import
load_aware_call
,
with_cancellation
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
FlexibleArgumentParser
,
get_open_zmq_ipc_path
,
from
vllm.utils
import
(
FlexibleArgumentParser
,
get_open_zmq_ipc_path
,
...
@@ -347,6 +347,24 @@ async def health(raw_request: Request) -> Response:
...
@@ -347,6 +347,24 @@ async def health(raw_request: Request) -> Response:
return
Response
(
status_code
=
200
)
return
Response
(
status_code
=
200
)
@router.get("/load")
async def get_server_load_metrics(request: Request):
    # Report the current server load counter kept on the application state.
    # The counter covers requests utilizing the GPU from these routes:
    # - /v1/chat/completions
    # - /v1/completions
    # - /v1/audio/transcriptions
    # - /v1/embeddings
    # - /pooling
    # - /score
    # - /v1/score
    # - /rerank
    # - /v1/rerank
    # - /v2/rerank
    current_load = request.app.state.server_load_metrics
    return JSONResponse(content={'server_load': current_load})
@
router
.
api_route
(
"/ping"
,
methods
=
[
"GET"
,
"POST"
])
@
router
.
api_route
(
"/ping"
,
methods
=
[
"GET"
,
"POST"
])
async
def
ping
(
raw_request
:
Request
)
->
Response
:
async
def
ping
(
raw_request
:
Request
)
->
Response
:
"""Ping check. Endpoint required for SageMaker"""
"""Ping check. Endpoint required for SageMaker"""
...
@@ -400,6 +418,7 @@ async def show_version():
...
@@ -400,6 +418,7 @@ async def show_version():
@
router
.
post
(
"/v1/chat/completions"
,
@
router
.
post
(
"/v1/chat/completions"
,
dependencies
=
[
Depends
(
validate_json_request
)])
dependencies
=
[
Depends
(
validate_json_request
)])
@
with_cancellation
@
with_cancellation
@
load_aware_call
async
def
create_chat_completion
(
request
:
ChatCompletionRequest
,
async
def
create_chat_completion
(
request
:
ChatCompletionRequest
,
raw_request
:
Request
):
raw_request
:
Request
):
handler
=
chat
(
raw_request
)
handler
=
chat
(
raw_request
)
...
@@ -421,6 +440,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
...
@@ -421,6 +440,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
@
router
.
post
(
"/v1/completions"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
router
.
post
(
"/v1/completions"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
with_cancellation
@
with_cancellation
@
load_aware_call
async
def
create_completion
(
request
:
CompletionRequest
,
raw_request
:
Request
):
async
def
create_completion
(
request
:
CompletionRequest
,
raw_request
:
Request
):
handler
=
completion
(
raw_request
)
handler
=
completion
(
raw_request
)
if
handler
is
None
:
if
handler
is
None
:
...
@@ -439,6 +459,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
...
@@ -439,6 +459,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
@
router
.
post
(
"/v1/embeddings"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
router
.
post
(
"/v1/embeddings"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
with_cancellation
@
with_cancellation
@
load_aware_call
async
def
create_embedding
(
request
:
EmbeddingRequest
,
raw_request
:
Request
):
async
def
create_embedding
(
request
:
EmbeddingRequest
,
raw_request
:
Request
):
handler
=
embedding
(
raw_request
)
handler
=
embedding
(
raw_request
)
if
handler
is
None
:
if
handler
is
None
:
...
@@ -485,6 +506,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
...
@@ -485,6 +506,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
@
router
.
post
(
"/pooling"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
router
.
post
(
"/pooling"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
with_cancellation
@
with_cancellation
@
load_aware_call
async
def
create_pooling
(
request
:
PoolingRequest
,
raw_request
:
Request
):
async
def
create_pooling
(
request
:
PoolingRequest
,
raw_request
:
Request
):
handler
=
pooling
(
raw_request
)
handler
=
pooling
(
raw_request
)
if
handler
is
None
:
if
handler
is
None
:
...
@@ -503,6 +525,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
...
@@ -503,6 +525,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
@
router
.
post
(
"/score"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
router
.
post
(
"/score"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
with_cancellation
@
with_cancellation
@
load_aware_call
async
def
create_score
(
request
:
ScoreRequest
,
raw_request
:
Request
):
async
def
create_score
(
request
:
ScoreRequest
,
raw_request
:
Request
):
handler
=
score
(
raw_request
)
handler
=
score
(
raw_request
)
if
handler
is
None
:
if
handler
is
None
:
...
@@ -521,6 +544,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
...
@@ -521,6 +544,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
@
router
.
post
(
"/v1/score"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
router
.
post
(
"/v1/score"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
with_cancellation
@
with_cancellation
@
load_aware_call
async
def
create_score_v1
(
request
:
ScoreRequest
,
raw_request
:
Request
):
async
def
create_score_v1
(
request
:
ScoreRequest
,
raw_request
:
Request
):
logger
.
warning
(
logger
.
warning
(
"To indicate that Score API is not part of standard OpenAI API, we "
"To indicate that Score API is not part of standard OpenAI API, we "
...
@@ -531,10 +555,10 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
...
@@ -531,10 +555,10 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
@
router
.
post
(
"/v1/audio/transcriptions"
)
@
router
.
post
(
"/v1/audio/transcriptions"
)
@
with_cancellation
@
with_cancellation
@
load_aware_call
async
def
create_transcriptions
(
request
:
Annotated
[
TranscriptionRequest
,
async
def
create_transcriptions
(
request
:
Annotated
[
TranscriptionRequest
,
Form
()],
Form
()],
raw_request
:
Request
):
raw_request
:
Request
):
handler
=
transcription
(
raw_request
)
handler
=
transcription
(
raw_request
)
if
handler
is
None
:
if
handler
is
None
:
return
base
(
raw_request
).
create_error_response
(
return
base
(
raw_request
).
create_error_response
(
...
@@ -556,6 +580,7 @@ async def create_transcriptions(request: Annotated[TranscriptionRequest,
...
@@ -556,6 +580,7 @@ async def create_transcriptions(request: Annotated[TranscriptionRequest,
@
router
.
post
(
"/rerank"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
router
.
post
(
"/rerank"
,
dependencies
=
[
Depends
(
validate_json_request
)])
@
with_cancellation
@
with_cancellation
@
load_aware_call
async
def
do_rerank
(
request
:
RerankRequest
,
raw_request
:
Request
):
async
def
do_rerank
(
request
:
RerankRequest
,
raw_request
:
Request
):
handler
=
rerank
(
raw_request
)
handler
=
rerank
(
raw_request
)
if
handler
is
None
:
if
handler
is
None
:
...
@@ -894,6 +919,9 @@ async def init_app_state(
...
@@ -894,6 +919,9 @@ async def init_app_state(
)
if
model_config
.
runner_type
==
"transcription"
else
None
)
if
model_config
.
runner_type
==
"transcription"
else
None
state
.
task
=
model_config
.
task
state
.
task
=
model_config
.
task
state
.
enable_server_load_tracking
=
args
.
enable_server_load_tracking
state
.
server_load_metrics
=
0
def
create_server_socket
(
addr
:
tuple
[
str
,
int
])
->
socket
.
socket
:
def
create_server_socket
(
addr
:
tuple
[
str
,
int
])
->
socket
.
socket
:
family
=
socket
.
AF_INET
family
=
socket
.
AF_INET
...
...
vllm/entrypoints/openai/cli_args.py
View file @
73deea2f
...
@@ -257,6 +257,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
...
@@ -257,6 +257,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
action
=
'store_true'
,
action
=
'store_true'
,
default
=
False
,
default
=
False
,
help
=
"If set to True, enable prompt_tokens_details in usage."
)
help
=
"If set to True, enable prompt_tokens_details in usage."
)
parser
.
add_argument
(
"--enable-server-load-tracking"
,
action
=
'store_true'
,
default
=
False
,
help
=
"If set to True, enable tracking server_load_metrics in the app state."
)
return
parser
return
parser
...
...
vllm/entrypoints/utils.py
View file @
73deea2f
...
@@ -4,6 +4,8 @@ import asyncio
...
@@ -4,6 +4,8 @@ import asyncio
import
functools
import
functools
from
fastapi
import
Request
from
fastapi
import
Request
from
fastapi.responses
import
JSONResponse
,
StreamingResponse
from
starlette.background
import
BackgroundTask
,
BackgroundTasks
async
def
listen_for_disconnect
(
request
:
Request
)
->
None
:
async
def
listen_for_disconnect
(
request
:
Request
)
->
None
:
...
@@ -17,9 +19,9 @@ async def listen_for_disconnect(request: Request) -> None:
...
@@ -17,9 +19,9 @@ async def listen_for_disconnect(request: Request) -> None:
def
with_cancellation
(
handler_func
):
def
with_cancellation
(
handler_func
):
"""Decorator that allows a route handler to be cancelled by client
"""Decorator that allows a route handler to be cancelled by client
disconnections.
disconnections.
This does _not_ use request.is_disconnected, which does not work with
This does _not_ use request.is_disconnected, which does not work with
middleware. Instead this follows the pattern from
middleware. Instead this follows the pattern from
starlette.StreamingResponse, which simultaneously awaits on two tasks- one
starlette.StreamingResponse, which simultaneously awaits on two tasks- one
to wait for an http disconnect message, and the other to do the work that we
to wait for an http disconnect message, and the other to do the work that we
want done. When the first task finishes, the other is cancelled.
want done. When the first task finishes, the other is cancelled.
...
@@ -57,3 +59,45 @@ def with_cancellation(handler_func):
...
@@ -57,3 +59,45 @@ def with_cancellation(handler_func):
return
None
return
None
return
wrapper
return
wrapper
def decrement_server_load(request: Request):
    """Lower the application's ``server_load_metrics`` counter by one."""
    app_state = request.app.state
    app_state.server_load_metrics -= 1
def load_aware_call(func):
    """Decorator that counts in-flight calls of a route handler.

    When ``app.state.enable_server_load_tracking`` is false the handler is
    called unchanged. Otherwise ``app.state.server_load_metrics`` is
    incremented before the handler runs and decremented again:

    * immediately, if the handler raises or is cancelled;
    * via a background task attached to the response, if the handler returns
      a ``JSONResponse`` or ``StreamingResponse``;
    * immediately, for any other return value.

    The wrapped handler must receive the request as the keyword argument
    ``raw_request``.
    """

    @functools.wraps(func)
    async def wrapper(*args, raw_request: Request, **kwargs):
        if not raw_request.app.state.enable_server_load_tracking:
            return await func(*args, raw_request=raw_request, **kwargs)

        raw_request.app.state.server_load_metrics += 1
        try:
            response = await func(*args, raw_request=raw_request, **kwargs)
        except BaseException:
            # BaseException, not Exception: asyncio.CancelledError derives
            # from BaseException, and with_cancellation cancels the handler
            # when the client disconnects. Catching only Exception would
            # leave the counter incremented forever on every disconnect.
            decrement_server_load(raw_request)
            raise

        if not isinstance(response, (JSONResponse, StreamingResponse)):
            # No response object to hang a background task on, so the
            # request is accounted as finished right away.
            decrement_server_load(raw_request)
            return response

        background = response.background
        if background is None:
            response.background = BackgroundTask(decrement_server_load,
                                                 raw_request)
        elif isinstance(background, BackgroundTasks):
            # Checked before BackgroundTask so that an existing task
            # collection is extended rather than re-wrapped.
            background.add_task(decrement_server_load, raw_request)
        elif isinstance(background, BackgroundTask):
            # Convert the single BackgroundTask to BackgroundTasks
            # and chain the decrement_server_load task to it
            tasks = BackgroundTasks()
            tasks.add_task(background.func, *background.args,
                           **background.kwargs)
            tasks.add_task(decrement_server_load, raw_request)
            response.background = tasks

        return response

    return wrapper
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment