Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3f42b05f
Unverified
Commit
3f42b05f
authored
Dec 03, 2025
by
Chauncey
Committed by
GitHub
Dec 03, 2025
Browse files
[Refactor] [1/N] to simplify the vLLM serving architecture (#28040)
Signed-off-by:
chaunceyjiang
<
chaunceyjiang@gmail.com
>
parent
69520bc6
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
570 additions
and
455 deletions
+570
-455
tests/entrypoints/openai/test_basic.py
tests/entrypoints/openai/test_basic.py
+1
-1
vllm/entrypoints/api_server.py
vllm/entrypoints/api_server.py
+1
-0
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+10
-445
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+1
-2
vllm/entrypoints/sagemaker/routes.py
vllm/entrypoints/sagemaker/routes.py
+1
-1
vllm/entrypoints/serve/__init__.py
vllm/entrypoints/serve/__init__.py
+60
-0
vllm/entrypoints/serve/disagg/__init__.py
vllm/entrypoints/serve/disagg/__init__.py
+0
-0
vllm/entrypoints/serve/disagg/api_router.py
vllm/entrypoints/serve/disagg/api_router.py
+110
-0
vllm/entrypoints/serve/disagg/protocol.py
vllm/entrypoints/serve/disagg/protocol.py
+90
-0
vllm/entrypoints/serve/disagg/serving.py
vllm/entrypoints/serve/disagg/serving.py
+7
-3
vllm/entrypoints/serve/elastic_ep/__init__.py
vllm/entrypoints/serve/elastic_ep/__init__.py
+0
-0
vllm/entrypoints/serve/elastic_ep/api_router.py
vllm/entrypoints/serve/elastic_ep/api_router.py
+96
-0
vllm/entrypoints/serve/elastic_ep/middleware.py
vllm/entrypoints/serve/elastic_ep/middleware.py
+49
-0
vllm/entrypoints/serve/instrumentator/__init__.py
vllm/entrypoints/serve/instrumentator/__init__.py
+0
-0
vllm/entrypoints/serve/instrumentator/health.py
vllm/entrypoints/serve/instrumentator/health.py
+33
-0
vllm/entrypoints/serve/instrumentator/metrics.py
vllm/entrypoints/serve/instrumentator/metrics.py
+46
-0
vllm/entrypoints/serve/lora/__init__.py
vllm/entrypoints/serve/lora/__init__.py
+0
-0
vllm/entrypoints/serve/lora/api_router.py
vllm/entrypoints/serve/lora/api_router.py
+16
-3
vllm/entrypoints/serve/profile/__init__.py
vllm/entrypoints/serve/profile/__init__.py
+0
-0
vllm/entrypoints/serve/profile/api_router.py
vllm/entrypoints/serve/profile/api_router.py
+49
-0
No files found.
tests/entrypoints/openai/test_basic.py
View file @
3f42b05f
...
@@ -232,7 +232,7 @@ async def test_server_load(server: RemoteOpenAIServer):
...
@@ -232,7 +232,7 @@ async def test_server_load(server: RemoteOpenAIServer):
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_health_check_engine_dead_error
():
async
def
test_health_check_engine_dead_error
():
# Import the health function directly to test it in isolation
# Import the health function directly to test it in isolation
from
vllm.entrypoints.
openai.api_server
import
health
from
vllm.entrypoints.
serve.instrumentator.health
import
health
# Create a mock request that simulates what FastAPI would provide
# Create a mock request that simulates what FastAPI would provide
mock_request
=
Mock
(
spec
=
Request
)
mock_request
=
Mock
(
spec
=
Request
)
...
...
vllm/entrypoints/api_server.py
View file @
3f42b05f
...
@@ -118,6 +118,7 @@ async def init_app(
...
@@ -118,6 +118,7 @@ async def init_app(
)
)
)
)
app
.
state
.
engine_client
=
engine
app
.
state
.
engine_client
=
engine
app
.
state
.
args
=
args
return
app
return
app
...
...
vllm/entrypoints/openai/api_server.py
View file @
3f42b05f
...
@@ -20,21 +20,15 @@ from http import HTTPStatus
...
@@ -20,21 +20,15 @@ from http import HTTPStatus
from
typing
import
Annotated
,
Any
,
Literal
from
typing
import
Annotated
,
Any
,
Literal
import
model_hosting_container_standards.sagemaker
as
sagemaker_standards
import
model_hosting_container_standards.sagemaker
as
sagemaker_standards
import
prometheus_client
import
pydantic
import
pydantic
import
regex
as
re
import
uvloop
import
uvloop
from
fastapi
import
APIRouter
,
Depends
,
FastAPI
,
Form
,
HTTPException
,
Query
,
Request
from
fastapi
import
APIRouter
,
Depends
,
FastAPI
,
Form
,
HTTPException
,
Query
,
Request
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
prometheus_client
import
make_asgi_app
from
prometheus_fastapi_instrumentator
import
Instrumentator
from
starlette.concurrency
import
iterate_in_threadpool
from
starlette.concurrency
import
iterate_in_threadpool
from
starlette.datastructures
import
URL
,
Headers
,
MutableHeaders
,
State
from
starlette.datastructures
import
URL
,
Headers
,
MutableHeaders
,
State
from
starlette.routing
import
Mount
from
starlette.types
import
ASGIApp
,
Message
,
Receive
,
Scope
,
Send
from
starlette.types
import
ASGIApp
,
Message
,
Receive
,
Scope
,
Send
from
typing_extensions
import
assert_never
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
...
@@ -56,17 +50,11 @@ from vllm.entrypoints.openai.protocol import (
...
@@ -56,17 +50,11 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionResponse
,
ChatCompletionResponse
,
CompletionRequest
,
CompletionRequest
,
CompletionResponse
,
CompletionResponse
,
DetokenizeRequest
,
DetokenizeResponse
,
ErrorInfo
,
ErrorInfo
,
ErrorResponse
,
ErrorResponse
,
GenerateRequest
,
GenerateResponse
,
ResponsesRequest
,
ResponsesRequest
,
ResponsesResponse
,
ResponsesResponse
,
StreamingResponsesResponse
,
StreamingResponsesResponse
,
TokenizeRequest
,
TokenizeResponse
,
TranscriptionRequest
,
TranscriptionRequest
,
TranscriptionResponseVariant
,
TranscriptionResponseVariant
,
TranslationRequest
,
TranslationRequest
,
...
@@ -80,8 +68,6 @@ from vllm.entrypoints.openai.serving_models import (
...
@@ -80,8 +68,6 @@ from vllm.entrypoints.openai.serving_models import (
OpenAIServingModels
,
OpenAIServingModels
,
)
)
from
vllm.entrypoints.openai.serving_responses
import
OpenAIServingResponses
from
vllm.entrypoints.openai.serving_responses
import
OpenAIServingResponses
from
vllm.entrypoints.openai.serving_tokenization
import
OpenAIServingTokenization
from
vllm.entrypoints.openai.serving_tokens
import
ServingTokens
from
vllm.entrypoints.openai.serving_transcription
import
(
from
vllm.entrypoints.openai.serving_transcription
import
(
OpenAIServingTranscription
,
OpenAIServingTranscription
,
OpenAIServingTranslation
,
OpenAIServingTranslation
,
...
@@ -92,6 +78,11 @@ from vllm.entrypoints.pooling.classify.serving import ServingClassification
...
@@ -92,6 +78,11 @@ from vllm.entrypoints.pooling.classify.serving import ServingClassification
from
vllm.entrypoints.pooling.embed.serving
import
OpenAIServingEmbedding
from
vllm.entrypoints.pooling.embed.serving
import
OpenAIServingEmbedding
from
vllm.entrypoints.pooling.pooling.serving
import
OpenAIServingPooling
from
vllm.entrypoints.pooling.pooling.serving
import
OpenAIServingPooling
from
vllm.entrypoints.pooling.score.serving
import
ServingScores
from
vllm.entrypoints.pooling.score.serving
import
ServingScores
from
vllm.entrypoints.serve.disagg.serving
import
ServingTokens
from
vllm.entrypoints.serve.elastic_ep.middleware
import
(
ScalingMiddleware
,
)
from
vllm.entrypoints.serve.tokenize.serving
import
OpenAIServingTokenization
from
vllm.entrypoints.tool_server
import
DemoToolServer
,
MCPToolServer
,
ToolServer
from
vllm.entrypoints.tool_server
import
DemoToolServer
,
MCPToolServer
,
ToolServer
from
vllm.entrypoints.utils
import
(
from
vllm.entrypoints.utils
import
(
cli_env_setup
,
cli_env_setup
,
...
@@ -109,8 +100,6 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
...
@@ -109,8 +100,6 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
from
vllm.utils.gc_utils
import
freeze_gc_heap
from
vllm.utils.gc_utils
import
freeze_gc_heap
from
vllm.utils.network_utils
import
is_valid_ipv6_address
from
vllm.utils.network_utils
import
is_valid_ipv6_address
from
vllm.utils.system_utils
import
decorate_logs
,
set_ulimit
from
vllm.utils.system_utils
import
decorate_logs
,
set_ulimit
from
vllm.v1.engine.exceptions
import
EngineDeadError
from
vllm.v1.metrics.prometheus
import
get_prometheus_registry
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
prometheus_multiproc_dir
:
tempfile
.
TemporaryDirectory
prometheus_multiproc_dir
:
tempfile
.
TemporaryDirectory
...
@@ -245,39 +234,6 @@ async def build_async_engine_client_from_engine_args(
...
@@ -245,39 +234,6 @@ async def build_async_engine_client_from_engine_args(
router
=
APIRouter
()
router
=
APIRouter
()
class
PrometheusResponse
(
Response
):
media_type
=
prometheus_client
.
CONTENT_TYPE_LATEST
def
mount_metrics
(
app
:
FastAPI
):
"""Mount prometheus metrics to a FastAPI app."""
registry
=
get_prometheus_registry
()
# `response_class=PrometheusResponse` is needed to return an HTTP response
# with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
# instead of the default "application/json" which is incorrect.
# See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364
Instrumentator
(
excluded_handlers
=
[
"/metrics"
,
"/health"
,
"/load"
,
"/ping"
,
"/version"
,
"/server_info"
,
],
registry
=
registry
,
).
add
().
instrument
(
app
).
expose
(
app
,
response_class
=
PrometheusResponse
)
# Add prometheus asgi middleware to route /metrics requests
metrics_route
=
Mount
(
"/metrics"
,
make_asgi_app
(
registry
=
registry
))
# Workaround for 307 Redirect for /metrics
metrics_route
.
path_regex
=
re
.
compile
(
"^/metrics(?P<path>.*)$"
)
app
.
routes
.
append
(
metrics_route
)
def
base
(
request
:
Request
)
->
OpenAIServing
:
def
base
(
request
:
Request
)
->
OpenAIServing
:
# Reuse the existing instance
# Reuse the existing instance
return
tokenization
(
request
)
return
tokenization
(
request
)
...
@@ -323,16 +279,6 @@ def generate_tokens(request: Request) -> ServingTokens | None:
...
@@ -323,16 +279,6 @@ def generate_tokens(request: Request) -> ServingTokens | None:
return
request
.
app
.
state
.
serving_tokens
return
request
.
app
.
state
.
serving_tokens
@
router
.
get
(
"/health"
,
response_class
=
Response
)
async
def
health
(
raw_request
:
Request
)
->
Response
:
"""Health check."""
try
:
await
engine_client
(
raw_request
).
check_health
()
return
Response
(
status_code
=
200
)
except
EngineDeadError
:
return
Response
(
status_code
=
503
)
@
router
.
get
(
"/load"
)
@
router
.
get
(
"/load"
)
async
def
get_server_load_metrics
(
request
:
Request
):
async
def
get_server_load_metrics
(
request
:
Request
):
# This endpoint returns the current server load metrics.
# This endpoint returns the current server load metrics.
...
@@ -352,167 +298,6 @@ async def get_server_load_metrics(request: Request):
...
@@ -352,167 +298,6 @@ async def get_server_load_metrics(request: Request):
return
JSONResponse
(
content
=
{
"server_load"
:
request
.
app
.
state
.
server_load_metrics
})
return
JSONResponse
(
content
=
{
"server_load"
:
request
.
app
.
state
.
server_load_metrics
})
@
router
.
post
(
"/pause"
)
async
def
pause_generation
(
raw_request
:
Request
,
wait_for_inflight_requests
:
bool
=
Query
(
False
),
clear_cache
:
bool
=
Query
(
True
),
)
->
JSONResponse
:
"""Pause generation requests to allow weight updates.
Args:
wait_for_inflight_requests: When ``True`` waits for in-flight
requests to finish before pausing. When ``False`` (default),
aborts any in-flight requests immediately.
clear_cache: Whether to clear KV/prefix caches after draining.
"""
engine
=
engine_client
(
raw_request
)
try
:
await
engine
.
pause_generation
(
wait_for_inflight_requests
=
wait_for_inflight_requests
,
clear_cache
=
clear_cache
,
)
return
JSONResponse
(
content
=
{
"status"
:
"paused"
},
status_code
=
HTTPStatus
.
OK
.
value
,
)
except
ValueError
as
err
:
return
JSONResponse
(
content
=
{
"error"
:
str
(
err
)},
status_code
=
HTTPStatus
.
BAD_REQUEST
.
value
,
)
except
Exception
as
err
:
# pragma: no cover - defensive
logger
.
exception
(
"Failed to pause generation"
)
return
JSONResponse
(
content
=
{
"error"
:
f
"Failed to pause generation:
{
err
}
"
},
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
,
)
@
router
.
post
(
"/resume"
)
async
def
resume_generation
(
raw_request
:
Request
)
->
JSONResponse
:
"""Resume generation after a pause."""
engine
=
engine_client
(
raw_request
)
try
:
await
engine
.
resume_generation
()
return
JSONResponse
(
content
=
{
"status"
:
"resumed"
},
status_code
=
HTTPStatus
.
OK
.
value
,
)
except
Exception
as
err
:
# pragma: no cover - defensive
logger
.
exception
(
"Failed to resume generation"
)
return
JSONResponse
(
content
=
{
"error"
:
f
"Failed to resume generation:
{
err
}
"
},
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
,
)
@
router
.
get
(
"/is_paused"
)
async
def
is_paused
(
raw_request
:
Request
)
->
JSONResponse
:
"""Return the current pause status."""
engine
=
engine_client
(
raw_request
)
try
:
paused
=
await
engine
.
is_paused
()
except
Exception
as
err
:
# pragma: no cover - defensive
logger
.
exception
(
"Failed to fetch pause status"
)
return
JSONResponse
(
content
=
{
"error"
:
f
"Failed to fetch pause status:
{
err
}
"
},
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
,
)
return
JSONResponse
(
content
=
{
"is_paused"
:
paused
})
@
router
.
post
(
"/tokenize"
,
dependencies
=
[
Depends
(
validate_json_request
)],
responses
=
{
HTTPStatus
.
BAD_REQUEST
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
NOT_FOUND
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
NOT_IMPLEMENTED
.
value
:
{
"model"
:
ErrorResponse
},
},
)
@
with_cancellation
async
def
tokenize
(
request
:
TokenizeRequest
,
raw_request
:
Request
):
handler
=
tokenization
(
raw_request
)
try
:
generator
=
await
handler
.
create_tokenize
(
request
,
raw_request
)
except
NotImplementedError
as
e
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
NOT_IMPLEMENTED
.
value
,
detail
=
str
(
e
)
)
from
e
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
,
detail
=
str
(
e
)
)
from
e
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
error
.
code
)
elif
isinstance
(
generator
,
TokenizeResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
())
assert_never
(
generator
)
@
router
.
post
(
"/detokenize"
,
dependencies
=
[
Depends
(
validate_json_request
)],
responses
=
{
HTTPStatus
.
BAD_REQUEST
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
NOT_FOUND
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
:
{
"model"
:
ErrorResponse
},
},
)
@
with_cancellation
async
def
detokenize
(
request
:
DetokenizeRequest
,
raw_request
:
Request
):
handler
=
tokenization
(
raw_request
)
try
:
generator
=
await
handler
.
create_detokenize
(
request
,
raw_request
)
except
OverflowError
as
e
:
raise
RequestValidationError
(
errors
=
[
str
(
e
)])
from
e
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
,
detail
=
str
(
e
)
)
from
e
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
error
.
code
)
elif
isinstance
(
generator
,
DetokenizeResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
())
assert_never
(
generator
)
def
maybe_register_tokenizer_info_endpoint
(
args
):
"""Conditionally register the tokenizer info endpoint if enabled."""
if
getattr
(
args
,
"enable_tokenizer_info_endpoint"
,
False
):
@
router
.
get
(
"/tokenizer_info"
)
async
def
get_tokenizer_info
(
raw_request
:
Request
):
"""Get comprehensive tokenizer information."""
result
=
await
tokenization
(
raw_request
).
get_tokenizer_info
()
return
JSONResponse
(
content
=
result
.
model_dump
(),
status_code
=
result
.
error
.
code
if
isinstance
(
result
,
ErrorResponse
)
else
200
,
)
@
router
.
get
(
"/v1/models"
)
@
router
.
get
(
"/v1/models"
)
async
def
show_available_models
(
raw_request
:
Request
):
async
def
show_available_models
(
raw_request
:
Request
):
handler
=
models
(
raw_request
)
handler
=
models
(
raw_request
)
...
@@ -898,33 +683,6 @@ if envs.VLLM_SERVER_DEV_MODE:
...
@@ -898,33 +683,6 @@ if envs.VLLM_SERVER_DEV_MODE:
await
engine_client
(
raw_request
).
reset_mm_cache
()
await
engine_client
(
raw_request
).
reset_mm_cache
()
return
Response
(
status_code
=
200
)
return
Response
(
status_code
=
200
)
@
router
.
post
(
"/sleep"
)
async
def
sleep
(
raw_request
:
Request
):
# get POST params
level
=
raw_request
.
query_params
.
get
(
"level"
,
"1"
)
await
engine_client
(
raw_request
).
sleep
(
int
(
level
))
# FIXME: in v0 with frontend multiprocessing, the sleep command
# is sent but does not finish yet when we return a response.
return
Response
(
status_code
=
200
)
@
router
.
post
(
"/wake_up"
)
async
def
wake_up
(
raw_request
:
Request
):
tags
=
raw_request
.
query_params
.
getlist
(
"tags"
)
if
tags
==
[]:
# set to None to wake up all tags if no tags are provided
tags
=
None
logger
.
info
(
"wake up the engine with tags: %s"
,
tags
)
await
engine_client
(
raw_request
).
wake_up
(
tags
)
# FIXME: in v0 with frontend multiprocessing, the wake-up command
# is sent but does not finish yet when we return a response.
return
Response
(
status_code
=
200
)
@
router
.
get
(
"/is_sleeping"
)
async
def
is_sleeping
(
raw_request
:
Request
):
logger
.
info
(
"check whether the engine is sleeping"
)
is_sleeping
=
await
engine_client
(
raw_request
).
is_sleeping
()
return
JSONResponse
(
content
=
{
"is_sleeping"
:
is_sleeping
})
@
router
.
post
(
"/collective_rpc"
)
@
router
.
post
(
"/collective_rpc"
)
async
def
collective_rpc
(
raw_request
:
Request
):
async
def
collective_rpc
(
raw_request
:
Request
):
try
:
try
:
...
@@ -952,138 +710,13 @@ if envs.VLLM_SERVER_DEV_MODE:
...
@@ -952,138 +710,13 @@ if envs.VLLM_SERVER_DEV_MODE:
return
Response
(
status_code
=
200
)
return
Response
(
status_code
=
200
)
response
:
list
[
Any
]
=
[]
response
:
list
[
Any
]
=
[]
for
result
in
results
:
for
result
in
results
:
if
result
is
None
or
isinstance
(
result
,
(
dict
,
list
)
)
:
if
result
is
None
or
isinstance
(
result
,
dict
|
list
):
response
.
append
(
result
)
response
.
append
(
result
)
else
:
else
:
response
.
append
(
str
(
result
))
response
.
append
(
str
(
result
))
return
JSONResponse
(
content
=
{
"results"
:
response
})
return
JSONResponse
(
content
=
{
"results"
:
response
})
@
router
.
post
(
"/scale_elastic_ep"
,
dependencies
=
[
Depends
(
validate_json_request
)],
responses
=
{
HTTPStatus
.
OK
.
value
:
{
"model"
:
dict
},
HTTPStatus
.
BAD_REQUEST
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
REQUEST_TIMEOUT
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
:
{
"model"
:
ErrorResponse
},
},
)
async
def
scale_elastic_ep
(
raw_request
:
Request
):
try
:
body
=
await
raw_request
.
json
()
except
json
.
JSONDecodeError
as
e
:
raise
HTTPException
(
status_code
=
400
,
detail
=
"Invalid JSON format"
)
from
e
# noqa: B904
new_data_parallel_size
=
body
.
get
(
"new_data_parallel_size"
)
drain_timeout
=
body
.
get
(
"drain_timeout"
,
120
)
# Default 2 minutes
if
new_data_parallel_size
is
None
:
raise
HTTPException
(
status_code
=
400
,
detail
=
"new_data_parallel_size is required"
)
if
not
isinstance
(
new_data_parallel_size
,
int
)
or
new_data_parallel_size
<=
0
:
raise
HTTPException
(
status_code
=
400
,
detail
=
"new_data_parallel_size must be a positive integer"
)
if
not
isinstance
(
drain_timeout
,
int
)
or
drain_timeout
<=
0
:
raise
HTTPException
(
status_code
=
400
,
detail
=
"drain_timeout must be a positive integer"
)
# Set scaling flag to prevent new requests
global
_scaling_elastic_ep
_scaling_elastic_ep
=
True
client
=
engine_client
(
raw_request
)
try
:
await
client
.
scale_elastic_ep
(
new_data_parallel_size
,
drain_timeout
)
return
JSONResponse
(
{
"message"
:
f
"Scaled to
{
new_data_parallel_size
}
data parallel engines"
,
}
)
except
TimeoutError
as
e
:
raise
HTTPException
(
status_code
=
408
,
detail
=
"Scale failed due to request drain timeout "
f
"after
{
drain_timeout
}
seconds"
,
)
from
e
except
Exception
as
e
:
logger
.
error
(
"Scale failed: %s"
,
e
)
raise
HTTPException
(
status_code
=
500
,
detail
=
"Scale failed"
)
from
e
finally
:
_scaling_elastic_ep
=
False
@
router
.
post
(
"/is_scaling_elastic_ep"
)
async
def
is_scaling_elastic_ep
(
raw_request
:
Request
):
return
JSONResponse
({
"is_scaling_elastic_ep"
:
_scaling_elastic_ep
})
@
router
.
post
(
"/inference/v1/generate"
,
dependencies
=
[
Depends
(
validate_json_request
)],
responses
=
{
HTTPStatus
.
OK
.
value
:
{
"content"
:
{
"text/event-stream"
:
{}}},
HTTPStatus
.
BAD_REQUEST
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
NOT_FOUND
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
:
{
"model"
:
ErrorResponse
},
},
)
@
with_cancellation
@
load_aware_call
async
def
generate
(
request
:
GenerateRequest
,
raw_request
:
Request
):
handler
=
generate_tokens
(
raw_request
)
if
handler
is
None
:
return
base
(
raw_request
).
create_error_response
(
message
=
"The model does not support generate tokens API"
)
try
:
generator
=
await
handler
.
serve_tokens
(
request
,
raw_request
)
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
,
detail
=
str
(
e
)
)
from
e
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
error
.
code
)
elif
isinstance
(
generator
,
GenerateResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
())
return
StreamingResponse
(
content
=
generator
,
media_type
=
"text/event-stream"
)
if
envs
.
VLLM_TORCH_PROFILER_DIR
:
logger
.
warning_once
(
"Torch Profiler is enabled in the API server. This should ONLY be "
"used for local development!"
)
elif
envs
.
VLLM_TORCH_CUDA_PROFILE
:
logger
.
warning_once
(
"CUDA Profiler is enabled in the API server. This should ONLY be "
"used for local development!"
)
if
envs
.
VLLM_TORCH_PROFILER_DIR
or
envs
.
VLLM_TORCH_CUDA_PROFILE
:
@
router
.
post
(
"/start_profile"
)
async
def
start_profile
(
raw_request
:
Request
):
logger
.
info
(
"Starting profiler..."
)
await
engine_client
(
raw_request
).
start_profile
()
logger
.
info
(
"Profiler started."
)
return
Response
(
status_code
=
200
)
@
router
.
post
(
"/stop_profile"
)
async
def
stop_profile
(
raw_request
:
Request
):
logger
.
info
(
"Stopping profiler..."
)
await
engine_client
(
raw_request
).
stop_profile
()
logger
.
info
(
"Profiler stopped."
)
return
Response
(
status_code
=
200
)
def
load_log_config
(
log_config_file
:
str
|
None
)
->
dict
|
None
:
def
load_log_config
(
log_config_file
:
str
|
None
)
->
dict
|
None
:
if
not
log_config_file
:
if
not
log_config_file
:
return
None
return
None
...
@@ -1176,41 +809,6 @@ class XRequestIdMiddleware:
...
@@ -1176,41 +809,6 @@ class XRequestIdMiddleware:
return
self
.
app
(
scope
,
receive
,
send_with_request_id
)
return
self
.
app
(
scope
,
receive
,
send_with_request_id
)
# Global variable to track scaling state
_scaling_elastic_ep
=
False
class
ScalingMiddleware
:
"""
Middleware that checks if the model is currently scaling and
returns a 503 Service Unavailable response if it is.
This middleware applies to all HTTP requests and prevents
processing when the model is in a scaling state.
"""
def
__init__
(
self
,
app
:
ASGIApp
)
->
None
:
self
.
app
=
app
def
__call__
(
self
,
scope
:
Scope
,
receive
:
Receive
,
send
:
Send
)
->
Awaitable
[
None
]:
if
scope
[
"type"
]
!=
"http"
:
return
self
.
app
(
scope
,
receive
,
send
)
# Check global scaling state
global
_scaling_elastic_ep
if
_scaling_elastic_ep
:
# Return 503 Service Unavailable response
response
=
JSONResponse
(
content
=
{
"error"
:
"The model is currently scaling. Please try again later."
},
status_code
=
503
,
)
return
response
(
scope
,
receive
,
send
)
return
self
.
app
(
scope
,
receive
,
send
)
def
_extract_content_from_chunk
(
chunk_data
:
dict
)
->
str
:
def
_extract_content_from_chunk
(
chunk_data
:
dict
)
->
str
:
"""Extract content from a streaming response chunk."""
"""Extract content from a streaming response chunk."""
try
:
try
:
...
@@ -1353,15 +951,10 @@ def build_app(args: Namespace) -> FastAPI:
...
@@ -1353,15 +951,10 @@ def build_app(args: Namespace) -> FastAPI:
)
)
else
:
else
:
app
=
FastAPI
(
lifespan
=
lifespan
)
app
=
FastAPI
(
lifespan
=
lifespan
)
app
.
state
.
args
=
args
from
vllm.entrypoints.serve
import
register_vllm_serve_api_routers
if
envs
.
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
register_vllm_serve_api_routers
(
app
)
logger
.
warning
(
"LoRA dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!"
)
from
vllm.entrypoints.dynamic_lora
import
register_dynamic_lora_routes
register_dynamic_lora_routes
(
router
)
from
vllm.entrypoints.sagemaker.routes
import
register_sagemaker_routes
from
vllm.entrypoints.sagemaker.routes
import
register_sagemaker_routes
...
@@ -1370,8 +963,6 @@ def build_app(args: Namespace) -> FastAPI:
...
@@ -1370,8 +963,6 @@ def build_app(args: Namespace) -> FastAPI:
app
.
root_path
=
args
.
root_path
app
.
root_path
=
args
.
root_path
mount_metrics
(
app
)
from
vllm.entrypoints.pooling
import
register_pooling_api_routers
from
vllm.entrypoints.pooling
import
register_pooling_api_routers
register_pooling_api_routers
(
app
)
register_pooling_api_routers
(
app
)
...
@@ -1462,31 +1053,6 @@ def build_app(args: Namespace) -> FastAPI:
...
@@ -1462,31 +1053,6 @@ def build_app(args: Namespace) -> FastAPI:
)
)
app
=
sagemaker_standards
.
bootstrap
(
app
)
app
=
sagemaker_standards
.
bootstrap
(
app
)
# Optional endpoints
if
args
.
tokens_only
:
@
app
.
post
(
"/abort_requests"
)
async
def
abort_requests
(
raw_request
:
Request
):
"""
Abort one or more requests. To be used in a
Disaggregated Everything setup.
"""
try
:
body
=
await
raw_request
.
json
()
except
json
.
JSONDecodeError
as
e
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
BAD_REQUEST
.
value
,
detail
=
f
"JSON decode error:
{
e
}
"
,
)
from
e
request_ids
=
body
.
get
(
"request_ids"
)
if
request_ids
is
None
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
BAD_REQUEST
.
value
,
detail
=
"Missing 'request_ids' in request body"
,
)
# Abort requests in background
asyncio
.
create_task
(
engine_client
(
raw_request
).
abort
(
request_ids
))
return
Response
(
status_code
=
200
)
return
app
return
app
...
@@ -1515,7 +1081,7 @@ async def init_app_state(
...
@@ -1515,7 +1081,7 @@ async def init_app_state(
state
.
engine_client
=
engine_client
state
.
engine_client
=
engine_client
state
.
log_stats
=
not
args
.
disable_log_stats
state
.
log_stats
=
not
args
.
disable_log_stats
state
.
vllm_config
=
vllm_config
state
.
vllm_config
=
vllm_config
state
.
args
=
args
supported_tasks
=
await
engine_client
.
get_supported_tasks
()
supported_tasks
=
await
engine_client
.
get_supported_tasks
()
logger
.
info
(
"Supported tasks: %s"
,
supported_tasks
)
logger
.
info
(
"Supported tasks: %s"
,
supported_tasks
)
...
@@ -1839,7 +1405,6 @@ async def run_server_worker(
...
@@ -1839,7 +1405,6 @@ async def run_server_worker(
args
,
args
,
client_config
=
client_config
,
client_config
=
client_config
,
)
as
engine_client
:
)
as
engine_client
:
maybe_register_tokenizer_info_endpoint
(
args
)
app
=
build_app
(
args
)
app
=
build_app
(
args
)
await
init_app_state
(
engine_client
,
app
.
state
,
args
)
await
init_app_state
(
engine_client
,
app
.
state
,
args
)
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
3f42b05f
...
@@ -74,8 +74,6 @@ from vllm.entrypoints.openai.protocol import (
...
@@ -74,8 +74,6 @@ from vllm.entrypoints.openai.protocol import (
ErrorResponse
,
ErrorResponse
,
FunctionCall
,
FunctionCall
,
FunctionDefinition
,
FunctionDefinition
,
GenerateRequest
,
GenerateResponse
,
ResponsesRequest
,
ResponsesRequest
,
TokenizeChatRequest
,
TokenizeChatRequest
,
TokenizeCompletionRequest
,
TokenizeCompletionRequest
,
...
@@ -87,6 +85,7 @@ from vllm.entrypoints.openai.protocol import (
...
@@ -87,6 +85,7 @@ from vllm.entrypoints.openai.protocol import (
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.entrypoints.openai.tool_parsers
import
ToolParser
,
ToolParserManager
from
vllm.entrypoints.openai.tool_parsers
import
ToolParser
,
ToolParserManager
from
vllm.entrypoints.renderer
import
BaseRenderer
,
CompletionRenderer
,
RenderConfig
from
vllm.entrypoints.renderer
import
BaseRenderer
,
CompletionRenderer
,
RenderConfig
from
vllm.entrypoints.serve.disagg.protocol
import
GenerateRequest
,
GenerateResponse
from
vllm.entrypoints.utils
import
_validate_truncation_size
from
vllm.entrypoints.utils
import
_validate_truncation_size
from
vllm.inputs.data
import
PromptType
from
vllm.inputs.data
import
PromptType
from
vllm.inputs.data
import
TokensPrompt
as
EngineTokensPrompt
from
vllm.inputs.data
import
TokensPrompt
as
EngineTokensPrompt
...
...
vllm/entrypoints/sagemaker/routes.py
View file @
3f42b05f
...
@@ -16,7 +16,6 @@ from vllm.entrypoints.openai.api_server import (
...
@@ -16,7 +16,6 @@ from vllm.entrypoints.openai.api_server import (
completion
,
completion
,
create_chat_completion
,
create_chat_completion
,
create_completion
,
create_completion
,
health
,
validate_json_request
,
validate_json_request
,
)
)
from
vllm.entrypoints.openai.protocol
import
(
from
vllm.entrypoints.openai.protocol
import
(
...
@@ -38,6 +37,7 @@ from vllm.entrypoints.pooling.score.api_router import (
...
@@ -38,6 +37,7 @@ from vllm.entrypoints.pooling.score.api_router import (
score
,
score
,
)
)
from
vllm.entrypoints.pooling.score.protocol
import
RerankRequest
,
ScoreRequest
from
vllm.entrypoints.pooling.score.protocol
import
RerankRequest
,
ScoreRequest
from
vllm.entrypoints.serve.instrumentator.health
import
health
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
# (requires typing_extensions >= 4.13)
...
...
vllm/entrypoints/serve/__init__.py
0 → 100644
View file @
3f42b05f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
fastapi
import
FastAPI
def register_vllm_serve_api_routers(app: FastAPI):
    """Attach every ``vllm.entrypoints.serve`` router to ``app``.

    Each router module is imported inside this function, immediately before
    its ``attach_router`` hook is invoked, and the hooks run in the fixed
    order written below.

    NOTE(review): the function-scope imports presumably avoid an import cycle
    with ``vllm.entrypoints.openai.api_server`` -- confirm before hoisting.
    """
    from vllm.entrypoints.serve.lora.api_router import (
        attach_router as attach_lora,
    )

    attach_lora(app)

    from vllm.entrypoints.serve.elastic_ep.api_router import (
        attach_router as attach_elastic_ep,
    )

    attach_elastic_ep(app)

    from vllm.entrypoints.serve.profile.api_router import (
        attach_router as attach_profile,
    )

    attach_profile(app)

    from vllm.entrypoints.serve.sleep.api_router import (
        attach_router as attach_sleep,
    )

    attach_sleep(app)

    from vllm.entrypoints.serve.tokenize.api_router import (
        attach_router as attach_tokenize,
    )

    attach_tokenize(app)

    from vllm.entrypoints.serve.disagg.api_router import (
        attach_router as attach_disagg,
    )

    attach_disagg(app)

    from vllm.entrypoints.serve.rlhf.api_router import (
        attach_router as attach_rlhf,
    )

    attach_rlhf(app)

    from vllm.entrypoints.serve.instrumentator.metrics import (
        attach_router as attach_metrics,
    )

    attach_metrics(app)

    from vllm.entrypoints.serve.instrumentator.health import (
        attach_router as attach_health,
    )

    attach_health(app)
vllm/entrypoints/serve/disagg/__init__.py
0 → 100644
View file @
3f42b05f
vllm/entrypoints/serve/disagg/api_router.py
0 → 100644
View file @
3f42b05f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
json
from
http
import
HTTPStatus
from
fastapi
import
APIRouter
,
Depends
,
FastAPI
,
HTTPException
,
Request
,
Response
from
fastapi.responses
import
JSONResponse
,
StreamingResponse
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.openai.api_server
import
validate_json_request
from
vllm.entrypoints.openai.protocol
import
(
ErrorResponse
,
)
from
vllm.entrypoints.serve.disagg.protocol
import
(
GenerateRequest
,
GenerateResponse
,
)
from
vllm.entrypoints.serve.disagg.serving
import
(
ServingTokens
,
)
from
vllm.entrypoints.serve.tokenize.serving
import
OpenAIServingTokenization
from
vllm.entrypoints.utils
import
(
load_aware_call
,
with_cancellation
,
)
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def tokenization(request: Request) -> OpenAIServingTokenization:
    """Return ``openai_serving_tokenization`` from the app state of ``request``."""
    app_state = request.app.state
    return app_state.openai_serving_tokenization
def generate_tokens(request: Request) -> ServingTokens | None:
    """Return ``serving_tokens`` from the app state of ``request``.

    The annotated return type allows ``None``; callers are expected to check.
    """
    app_state = request.app.state
    return app_state.serving_tokens
def engine_client(request: Request) -> EngineClient:
    """Return ``engine_client`` from the app state of ``request``."""
    app_state = request.app.state
    return app_state.engine_client
router
=
APIRouter
()
# POST /inference/v1/generate
# Documented with comments rather than a docstring so that the generated
# OpenAPI description of the route stays unchanged.
@router.post(
    "/inference/v1/generate",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def generate(request: GenerateRequest, raw_request: Request):
    serving = generate_tokens(raw_request)
    if serving is None:
        # No tokens handler on the app state: hand back the error object
        # produced by the tokenization handler as-is.
        error_source = tokenization(raw_request)
        return error_source.create_error_response(
            message="The model does not support generate tokens API"
        )

    try:
        result = await serving.serve_tokens(request, raw_request)
    except Exception as exc:
        # Any failure while serving is reported to the client as a 500.
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            detail=str(exc),
        ) from exc

    # Error object: JSON body carrying the error's own status code.
    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(),
            status_code=result.error.code,
        )

    # Complete (non-streaming) response: plain JSON.
    if isinstance(result, GenerateResponse):
        return JSONResponse(content=result.model_dump())

    # Anything else is passed through as the body of an event stream.
    return StreamingResponse(content=result, media_type="text/event-stream")
def attach_router(app: FastAPI):
    """Include the disaggregated-serving ``router`` in ``app``.

    When ``app.state.args.tokens_only`` is truthy, a ``POST /abort_requests``
    route is first added to the module-level ``router``.
    """
    if getattr(app.state.args, "tokens_only", False):
        # The event loop only keeps weak references to tasks, so a
        # fire-and-forget task may be garbage-collected before it finishes.
        # Hold the abort tasks here until they complete.
        pending_aborts: set[asyncio.Task] = set()

        @router.post("/abort_requests")
        async def abort_requests(raw_request: Request):
            """
            Abort one or more requests. To be used in a
            Disaggregated Everything setup.
            """
            try:
                body = await raw_request.json()
            except json.JSONDecodeError as e:
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail=f"JSON decode error: {e}",
                ) from e

            # Valid JSON may still be a list/string/number, which has no
            # ``.get``; treat that the same as a missing ``request_ids`` key
            # instead of failing with an unhandled AttributeError.
            request_ids = body.get("request_ids") if isinstance(body, dict) else None
            if request_ids is None:
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail="Missing 'request_ids' in request body",
                )

            # Abort requests in background
            abort_task = asyncio.create_task(
                engine_client(raw_request).abort(request_ids)
            )
            pending_aborts.add(abort_task)
            abort_task.add_done_callback(pending_aborts.discard)
            return Response(status_code=200)

    app.include_router(router)
vllm/entrypoints/serve/disagg/protocol.py
0 → 100644
View file @
3f42b05f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
from
pydantic
import
BaseModel
,
Field
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionLogProbs
,
Logprob
,
SamplingParams
,
StreamOptions
,
)
from
vllm.utils
import
random_uuid
####### Tokens IN <> Tokens OUT #######
# Request body of the tokens-in / tokens-out generate API.
# Kept as a comment (no class docstring) so the generated JSON schema of the
# model does not gain a description.
class GenerateRequest(BaseModel):
    # Identifier of the request; a random uuid string when the caller omits it.
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )

    token_ids: list[int]
    """The token ids to generate text from."""

    # features: MultiModalFeatureSpec
    # TODO (NickLucche): implement once Renderer work is completed
    features: str | None = None
    """The processed MM inputs for the model."""

    sampling_params: SamplingParams
    """The sampling parameters for the model."""

    model: str | None = None

    # Streaming switches.
    stream: bool | None = False
    stream_options: StreamOptions | None = None

    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )
# One entry of ``GenerateResponse.choices``.
# Kept as a comment (no class docstring) so the generated JSON schema of the
# model does not gain a description.
class GenerateResponseChoice(BaseModel):
    index: int

    logprobs: ChatCompletionLogProbs | None = None

    # per OpenAI spec this is the default
    finish_reason: str | None = "stop"

    token_ids: list[int] | None = None
# Non-streaming response body of the tokens-in / tokens-out generate API.
# Kept as a comment (no class docstring) so the generated JSON schema of the
# model does not gain a description.
class GenerateResponse(BaseModel):
    # Defaults to a random uuid string when not supplied explicitly.
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )

    choices: list[GenerateResponseChoice]

    prompt_logprobs: list[dict[int, Logprob] | None] | None = None

    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )
vllm/entrypoints/
openai
/serving
_tokens
.py
→
vllm/entrypoints/
serve/disagg
/serving.py
View file @
3f42b05f
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
asyncio
import
time
import
time
from
collections.abc
import
AsyncGenerator
from
collections.abc
import
AsyncGenerator
...
@@ -14,15 +16,17 @@ from vllm.entrypoints.openai.protocol import (
...
@@ -14,15 +16,17 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProbs
,
ChatCompletionLogProbs
,
ChatCompletionLogProbsContent
,
ChatCompletionLogProbsContent
,
ErrorResponse
,
ErrorResponse
,
GenerateRequest
,
GenerateResponse
,
GenerateResponseChoice
,
PromptTokenUsageInfo
,
PromptTokenUsageInfo
,
RequestResponseMetadata
,
RequestResponseMetadata
,
UsageInfo
,
UsageInfo
,
)
)
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
,
clamp_prompt_logprobs
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
,
clamp_prompt_logprobs
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.entrypoints.serve.disagg.protocol
import
(
GenerateRequest
,
GenerateResponse
,
GenerateResponseChoice
,
)
from
vllm.inputs.data
import
TokensPrompt
as
EngineTokensPrompt
from
vllm.inputs.data
import
TokensPrompt
as
EngineTokensPrompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.logprobs
import
Logprob
from
vllm.logprobs
import
Logprob
...
...
vllm/entrypoints/serve/elastic_ep/__init__.py
0 → 100644
View file @
3f42b05f
vllm/entrypoints/serve/elastic_ep/api_router.py
0 → 100644
View file @
3f42b05f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
from
http
import
HTTPStatus
from
fastapi
import
APIRouter
,
Depends
,
FastAPI
,
HTTPException
,
Request
from
fastapi.responses
import
JSONResponse
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.openai.api_server
import
validate_json_request
from
vllm.entrypoints.openai.protocol
import
(
ErrorResponse
,
)
from
vllm.entrypoints.serve.elastic_ep.middleware
import
(
get_scaling_elastic_ep
,
set_scaling_elastic_ep
,
)
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def engine_client(request: Request) -> EngineClient:
    """Return ``engine_client`` from the app state of ``request``."""
    app_state = request.app.state
    return app_state.engine_client
router
=
APIRouter
()
def _is_positive_int(value: object) -> bool:
    """Return True only for a real ``int`` greater than zero.

    ``bool`` is a subclass of ``int``, so JSON ``true`` would otherwise be
    accepted as the integer 1.
    """
    return isinstance(value, int) and not isinstance(value, bool) and value > 0


# POST /scale_elastic_ep
# Body: JSON object with ``new_data_parallel_size`` (required positive int) and
# ``drain_timeout`` (optional positive int, default 120).
# Responds 400 on invalid input, 408 when the engine raises TimeoutError and
# 500 on any other failure. Documented with comments rather than a docstring
# so the generated OpenAPI description of the route stays unchanged.
@router.post(
    "/scale_elastic_ep",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"model": dict},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
async def scale_elastic_ep(raw_request: Request):
    try:
        body = await raw_request.json()
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail="Invalid JSON format") from e

    # Valid JSON that is not an object (list, string, number) has no ``.get``.
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    new_data_parallel_size = body.get("new_data_parallel_size")
    drain_timeout = body.get("drain_timeout", 120)  # Default 2 minutes

    if new_data_parallel_size is None:
        raise HTTPException(
            status_code=400, detail="new_data_parallel_size is required"
        )

    if not _is_positive_int(new_data_parallel_size):
        raise HTTPException(
            status_code=400,
            detail="new_data_parallel_size must be a positive integer",
        )

    if not _is_positive_int(drain_timeout):
        raise HTTPException(
            status_code=400, detail="drain_timeout must be a positive integer"
        )

    # Resolve the client before raising the flag: if the lookup fails, the
    # flag must not be left set, since nothing would reset it afterwards.
    client = engine_client(raw_request)

    # Set scaling flag to prevent new requests
    set_scaling_elastic_ep(True)
    try:
        await client.scale_elastic_ep(new_data_parallel_size, drain_timeout)
        return JSONResponse(
            {
                "message": f"Scaled to {new_data_parallel_size} data parallel engines",
            }
        )
    except TimeoutError as e:
        raise HTTPException(
            status_code=408,
            detail="Scale failed due to request drain timeout "
            f"after {drain_timeout} seconds",
        ) from e
    except Exception as e:
        logger.error("Scale failed: %s", e)
        raise HTTPException(status_code=500, detail="Scale failed") from e
    finally:
        # Always clear the flag, whether scaling succeeded or failed.
        set_scaling_elastic_ep(False)
# POST /is_scaling_elastic_ep -- reports the current value of the scaling flag.
# Documented with comments rather than a docstring so the generated OpenAPI
# description of the route stays unchanged.
@router.post("/is_scaling_elastic_ep")
async def is_scaling_elastic_ep(raw_request: Request):
    payload = {"is_scaling_elastic_ep": get_scaling_elastic_ep()}
    return JSONResponse(payload)
def attach_router(app: FastAPI):
    """Include this module's ``router`` in ``app``."""
    app.include_router(router)
vllm/entrypoints/serve/elastic_ep/middleware.py
0 → 100644
View file @
3f42b05f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Awaitable
from
fastapi.responses
import
JSONResponse
from
starlette.types
import
ASGIApp
,
Receive
,
Scope
,
Send
# Module-level flag tracking whether a scaling operation is in progress.
# Read through get_scaling_elastic_ep() and written through
# set_scaling_elastic_ep().
_scaling_elastic_ep = False


def get_scaling_elastic_ep() -> bool:
    """Return the current value of the module-level scaling flag."""
    return _scaling_elastic_ep


def set_scaling_elastic_ep(value: bool) -> None:
    """Replace the module-level scaling flag with ``value``."""
    global _scaling_elastic_ep
    _scaling_elastic_ep = value
class ScalingMiddleware:
    """
    ASGI middleware that rejects HTTP requests while the model is scaling.

    While ``get_scaling_elastic_ep()`` is truthy, every HTTP request is
    answered with a 503 Service Unavailable JSON response. Non-HTTP scopes,
    and HTTP requests received when not scaling, go to the wrapped app.
    """

    def __init__(self, app: ASGIApp) -> None:
        # The downstream ASGI application this middleware wraps.
        self.app = app

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        # The scaling flag is only consulted for HTTP scopes.
        if scope["type"] == "http" and get_scaling_elastic_ep():
            unavailable = JSONResponse(
                content={
                    "error": "The model is currently scaling. Please try again later."
                },
                status_code=503,
            )
            # A response object is itself callable with the ASGI triple.
            return unavailable(scope, receive, send)

        return self.app(scope, receive, send)
vllm/entrypoints/serve/instrumentator/__init__.py
0 → 100644
View file @
3f42b05f
vllm/entrypoints/serve/instrumentator/health.py
0 → 100644
View file @
3f42b05f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
fastapi
import
APIRouter
,
Request
from
fastapi.responses
import
Response
from
vllm.engine.protocol
import
EngineClient
from
vllm.logger
import
init_logger
from
vllm.v1.engine.exceptions
import
EngineDeadError
logger
=
init_logger
(
__name__
)
router
=
APIRouter
()
def engine_client(request: Request) -> EngineClient:
    """Return ``engine_client`` from the app state of ``request``."""
    app_state = request.app.state
    return app_state.engine_client
@router.get("/health", response_class=Response)
async def health(raw_request: Request) -> Response:
    """Health check."""
    # 200 when the engine's health check passes, 503 when it raises
    # EngineDeadError; any other exception propagates to the caller.
    client = engine_client(raw_request)
    try:
        await client.check_health()
    except EngineDeadError:
        return Response(status_code=503)
    return Response(status_code=200)
def attach_router(app):
    """Include this module's health ``router`` in ``app``."""
    app.include_router(router)
vllm/entrypoints/serve/instrumentator/metrics.py
0 → 100644
View file @
3f42b05f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
re
import
prometheus_client
from
fastapi
import
FastAPI
,
Response
from
prometheus_client
import
make_asgi_app
from
prometheus_fastapi_instrumentator
import
Instrumentator
from
starlette.routing
import
Mount
from
vllm.v1.metrics.prometheus
import
get_prometheus_registry
class PrometheusResponse(Response):
    """Response whose media type is ``prometheus_client.CONTENT_TYPE_LATEST``."""

    media_type = prometheus_client.CONTENT_TYPE_LATEST
def attach_router(app: FastAPI):
    """Mount prometheus metrics to a FastAPI app."""
    registry = get_prometheus_registry()

    # Handlers listed here are excluded from instrumentation.
    skipped_handlers = [
        "/metrics",
        "/health",
        "/load",
        "/ping",
        "/version",
        "/server_info",
    ]
    instrumented = (
        Instrumentator(excluded_handlers=skipped_handlers, registry=registry)
        .add()
        .instrument(app)
    )
    # Passing `response_class=PrometheusResponse` makes the HTTP response carry
    # the header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
    # rather than the default "application/json", which is incorrect.
    # See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364
    instrumented.expose(app, response_class=PrometheusResponse)

    # Route /metrics requests to the prometheus ASGI app
    prometheus_app = make_asgi_app(registry=registry)
    metrics_route = Mount("/metrics", prometheus_app)

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(metrics_route)
vllm/entrypoints/serve/lora/__init__.py
0 → 100644
View file @
3f42b05f
vllm/entrypoints/
dynamic_lora
.py
→
vllm/entrypoints/
serve/lora/api_router
.py
View file @
3f42b05f
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
model_hosting_container_standards.sagemaker
as
sagemaker_standards
import
model_hosting_container_standards.sagemaker
as
sagemaker_standards
from
fastapi
import
APIRouter
,
Depends
,
Request
from
fastapi
import
APIRouter
,
Depends
,
FastAPI
,
Request
from
fastapi.responses
import
JSONResponse
,
Response
from
fastapi.responses
import
JSONResponse
,
Response
from
vllm
import
envs
from
vllm.entrypoints.openai.api_server
import
models
,
validate_json_request
from
vllm.entrypoints.openai.api_server
import
models
,
validate_json_request
from
vllm.entrypoints.openai.protocol
import
(
from
vllm.entrypoints.openai.protocol
import
(
ErrorResponse
,
ErrorResponse
,
...
@@ -14,9 +17,18 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
...
@@ -14,9 +17,18 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
router
=
APIRouter
()
def
register_dynamic_lora_routes
(
router
:
APIRouter
):
def
attach_router
(
app
:
FastAPI
):
if
not
envs
.
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
"""If LoRA dynamic loading & unloading is not enabled, do nothing."""
return
logger
.
warning
(
"LoRA dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!"
)
@
sagemaker_standards
.
register_load_adapter_handler
(
@
sagemaker_standards
.
register_load_adapter_handler
(
request_shape
=
{
request_shape
=
{
"lora_name"
:
"body.name"
,
"lora_name"
:
"body.name"
,
...
@@ -54,4 +66,5 @@ def register_dynamic_lora_routes(router: APIRouter):
...
@@ -54,4 +66,5 @@ def register_dynamic_lora_routes(router: APIRouter):
return
Response
(
status_code
=
200
,
content
=
response
)
return
Response
(
status_code
=
200
,
content
=
response
)
return
router
# register the router
app
.
include_router
(
router
)
vllm/entrypoints/serve/profile/__init__.py
0 → 100644
View file @
3f42b05f
vllm/entrypoints/serve/profile/api_router.py
0 → 100644
View file @
3f42b05f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
fastapi
import
APIRouter
,
FastAPI
,
Request
from
fastapi.responses
import
Response
import
vllm.envs
as
envs
from
vllm.engine.protocol
import
EngineClient
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
router
=
APIRouter
()
def engine_client(request: Request) -> EngineClient:
    """Return ``engine_client`` from the app state of ``request``."""
    app_state = request.app.state
    return app_state.engine_client
# POST /start_profile -- asks the engine client to start profiling, logging
# before and after the call. Documented with comments rather than a docstring
# so the generated OpenAPI description of the route stays unchanged.
@router.post("/start_profile")
async def start_profile(raw_request: Request):
    logger.info("Starting profiler...")
    client = engine_client(raw_request)
    await client.start_profile()
    logger.info("Profiler started.")
    return Response(status_code=200)
# POST /stop_profile -- asks the engine client to stop profiling, logging
# before and after the call. Documented with comments rather than a docstring
# so the generated OpenAPI description of the route stays unchanged.
@router.post("/stop_profile")
async def stop_profile(raw_request: Request):
    logger.info("Stopping profiler...")
    client = engine_client(raw_request)
    await client.stop_profile()
    logger.info("Profiler stopped.")
    return Response(status_code=200)
def attach_router(app: FastAPI):
    """Include the profiler ``router`` in ``app`` when profiling is enabled.

    The routes are only exposed when ``envs.VLLM_TORCH_PROFILER_DIR`` or
    ``envs.VLLM_TORCH_CUDA_PROFILE`` is truthy; a one-time warning naming the
    active profiler is logged first. The torch profiler setting is checked
    first and wins when both are set.
    """
    if envs.VLLM_TORCH_PROFILER_DIR:
        warning = (
            "Torch Profiler is enabled in the API server. This should ONLY be "
            "used for local development!"
        )
    elif envs.VLLM_TORCH_CUDA_PROFILE:
        warning = (
            "CUDA Profiler is enabled in the API server. This should ONLY be "
            "used for local development!"
        )
    else:
        # Neither profiler is enabled: leave the app untouched.
        return

    logger.warning_once(warning)
    app.include_router(router)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment