Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3f42b05f
Unverified
Commit
3f42b05f
authored
Dec 03, 2025
by
Chauncey
Committed by
GitHub
Dec 03, 2025
Browse files
[Refactor] [1/N] to simplify the vLLM serving architecture (#28040)
Signed-off-by:
chaunceyjiang
<
chaunceyjiang@gmail.com
>
parent
69520bc6
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
570 additions
and
455 deletions
+570
-455
tests/entrypoints/openai/test_basic.py
tests/entrypoints/openai/test_basic.py
+1
-1
vllm/entrypoints/api_server.py
vllm/entrypoints/api_server.py
+1
-0
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+10
-445
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+1
-2
vllm/entrypoints/sagemaker/routes.py
vllm/entrypoints/sagemaker/routes.py
+1
-1
vllm/entrypoints/serve/__init__.py
vllm/entrypoints/serve/__init__.py
+60
-0
vllm/entrypoints/serve/disagg/__init__.py
vllm/entrypoints/serve/disagg/__init__.py
+0
-0
vllm/entrypoints/serve/disagg/api_router.py
vllm/entrypoints/serve/disagg/api_router.py
+110
-0
vllm/entrypoints/serve/disagg/protocol.py
vllm/entrypoints/serve/disagg/protocol.py
+90
-0
vllm/entrypoints/serve/disagg/serving.py
vllm/entrypoints/serve/disagg/serving.py
+7
-3
vllm/entrypoints/serve/elastic_ep/__init__.py
vllm/entrypoints/serve/elastic_ep/__init__.py
+0
-0
vllm/entrypoints/serve/elastic_ep/api_router.py
vllm/entrypoints/serve/elastic_ep/api_router.py
+96
-0
vllm/entrypoints/serve/elastic_ep/middleware.py
vllm/entrypoints/serve/elastic_ep/middleware.py
+49
-0
vllm/entrypoints/serve/instrumentator/__init__.py
vllm/entrypoints/serve/instrumentator/__init__.py
+0
-0
vllm/entrypoints/serve/instrumentator/health.py
vllm/entrypoints/serve/instrumentator/health.py
+33
-0
vllm/entrypoints/serve/instrumentator/metrics.py
vllm/entrypoints/serve/instrumentator/metrics.py
+46
-0
vllm/entrypoints/serve/lora/__init__.py
vllm/entrypoints/serve/lora/__init__.py
+0
-0
vllm/entrypoints/serve/lora/api_router.py
vllm/entrypoints/serve/lora/api_router.py
+16
-3
vllm/entrypoints/serve/profile/__init__.py
vllm/entrypoints/serve/profile/__init__.py
+0
-0
vllm/entrypoints/serve/profile/api_router.py
vllm/entrypoints/serve/profile/api_router.py
+49
-0
No files found.
tests/entrypoints/openai/test_basic.py
View file @
3f42b05f
...
...
@@ -232,7 +232,7 @@ async def test_server_load(server: RemoteOpenAIServer):
@
pytest
.
mark
.
asyncio
async
def
test_health_check_engine_dead_error
():
# Import the health function directly to test it in isolation
from
vllm.entrypoints.
openai.api_server
import
health
from
vllm.entrypoints.
serve.instrumentator.health
import
health
# Create a mock request that simulates what FastAPI would provide
mock_request
=
Mock
(
spec
=
Request
)
...
...
vllm/entrypoints/api_server.py
View file @
3f42b05f
...
...
@@ -118,6 +118,7 @@ async def init_app(
)
)
app
.
state
.
engine_client
=
engine
app
.
state
.
args
=
args
return
app
...
...
vllm/entrypoints/openai/api_server.py
View file @
3f42b05f
...
...
@@ -20,21 +20,15 @@ from http import HTTPStatus
from
typing
import
Annotated
,
Any
,
Literal
import
model_hosting_container_standards.sagemaker
as
sagemaker_standards
import
prometheus_client
import
pydantic
import
regex
as
re
import
uvloop
from
fastapi
import
APIRouter
,
Depends
,
FastAPI
,
Form
,
HTTPException
,
Query
,
Request
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
prometheus_client
import
make_asgi_app
from
prometheus_fastapi_instrumentator
import
Instrumentator
from
starlette.concurrency
import
iterate_in_threadpool
from
starlette.datastructures
import
URL
,
Headers
,
MutableHeaders
,
State
from
starlette.routing
import
Mount
from
starlette.types
import
ASGIApp
,
Message
,
Receive
,
Scope
,
Send
from
typing_extensions
import
assert_never
import
vllm.envs
as
envs
from
vllm.config
import
VllmConfig
...
...
@@ -56,17 +50,11 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionResponse
,
CompletionRequest
,
CompletionResponse
,
DetokenizeRequest
,
DetokenizeResponse
,
ErrorInfo
,
ErrorResponse
,
GenerateRequest
,
GenerateResponse
,
ResponsesRequest
,
ResponsesResponse
,
StreamingResponsesResponse
,
TokenizeRequest
,
TokenizeResponse
,
TranscriptionRequest
,
TranscriptionResponseVariant
,
TranslationRequest
,
...
...
@@ -80,8 +68,6 @@ from vllm.entrypoints.openai.serving_models import (
OpenAIServingModels
,
)
from
vllm.entrypoints.openai.serving_responses
import
OpenAIServingResponses
from
vllm.entrypoints.openai.serving_tokenization
import
OpenAIServingTokenization
from
vllm.entrypoints.openai.serving_tokens
import
ServingTokens
from
vllm.entrypoints.openai.serving_transcription
import
(
OpenAIServingTranscription
,
OpenAIServingTranslation
,
...
...
@@ -92,6 +78,11 @@ from vllm.entrypoints.pooling.classify.serving import ServingClassification
from
vllm.entrypoints.pooling.embed.serving
import
OpenAIServingEmbedding
from
vllm.entrypoints.pooling.pooling.serving
import
OpenAIServingPooling
from
vllm.entrypoints.pooling.score.serving
import
ServingScores
from
vllm.entrypoints.serve.disagg.serving
import
ServingTokens
from
vllm.entrypoints.serve.elastic_ep.middleware
import
(
ScalingMiddleware
,
)
from
vllm.entrypoints.serve.tokenize.serving
import
OpenAIServingTokenization
from
vllm.entrypoints.tool_server
import
DemoToolServer
,
MCPToolServer
,
ToolServer
from
vllm.entrypoints.utils
import
(
cli_env_setup
,
...
...
@@ -109,8 +100,6 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
from
vllm.utils.gc_utils
import
freeze_gc_heap
from
vllm.utils.network_utils
import
is_valid_ipv6_address
from
vllm.utils.system_utils
import
decorate_logs
,
set_ulimit
from
vllm.v1.engine.exceptions
import
EngineDeadError
from
vllm.v1.metrics.prometheus
import
get_prometheus_registry
from
vllm.version
import
__version__
as
VLLM_VERSION
prometheus_multiproc_dir
:
tempfile
.
TemporaryDirectory
...
...
@@ -245,39 +234,6 @@ async def build_async_engine_client_from_engine_args(
router
=
APIRouter
()
class
PrometheusResponse
(
Response
):
media_type
=
prometheus_client
.
CONTENT_TYPE_LATEST
def
mount_metrics
(
app
:
FastAPI
):
"""Mount prometheus metrics to a FastAPI app."""
registry
=
get_prometheus_registry
()
# `response_class=PrometheusResponse` is needed to return an HTTP response
# with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
# instead of the default "application/json" which is incorrect.
# See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364
Instrumentator
(
excluded_handlers
=
[
"/metrics"
,
"/health"
,
"/load"
,
"/ping"
,
"/version"
,
"/server_info"
,
],
registry
=
registry
,
).
add
().
instrument
(
app
).
expose
(
app
,
response_class
=
PrometheusResponse
)
# Add prometheus asgi middleware to route /metrics requests
metrics_route
=
Mount
(
"/metrics"
,
make_asgi_app
(
registry
=
registry
))
# Workaround for 307 Redirect for /metrics
metrics_route
.
path_regex
=
re
.
compile
(
"^/metrics(?P<path>.*)$"
)
app
.
routes
.
append
(
metrics_route
)
def
base
(
request
:
Request
)
->
OpenAIServing
:
# Reuse the existing instance
return
tokenization
(
request
)
...
...
@@ -323,16 +279,6 @@ def generate_tokens(request: Request) -> ServingTokens | None:
return
request
.
app
.
state
.
serving_tokens
@
router
.
get
(
"/health"
,
response_class
=
Response
)
async
def
health
(
raw_request
:
Request
)
->
Response
:
"""Health check."""
try
:
await
engine_client
(
raw_request
).
check_health
()
return
Response
(
status_code
=
200
)
except
EngineDeadError
:
return
Response
(
status_code
=
503
)
@
router
.
get
(
"/load"
)
async
def
get_server_load_metrics
(
request
:
Request
):
# This endpoint returns the current server load metrics.
...
...
@@ -352,167 +298,6 @@ async def get_server_load_metrics(request: Request):
return
JSONResponse
(
content
=
{
"server_load"
:
request
.
app
.
state
.
server_load_metrics
})
@
router
.
post
(
"/pause"
)
async
def
pause_generation
(
raw_request
:
Request
,
wait_for_inflight_requests
:
bool
=
Query
(
False
),
clear_cache
:
bool
=
Query
(
True
),
)
->
JSONResponse
:
"""Pause generation requests to allow weight updates.
Args:
wait_for_inflight_requests: When ``True`` waits for in-flight
requests to finish before pausing. When ``False`` (default),
aborts any in-flight requests immediately.
clear_cache: Whether to clear KV/prefix caches after draining.
"""
engine
=
engine_client
(
raw_request
)
try
:
await
engine
.
pause_generation
(
wait_for_inflight_requests
=
wait_for_inflight_requests
,
clear_cache
=
clear_cache
,
)
return
JSONResponse
(
content
=
{
"status"
:
"paused"
},
status_code
=
HTTPStatus
.
OK
.
value
,
)
except
ValueError
as
err
:
return
JSONResponse
(
content
=
{
"error"
:
str
(
err
)},
status_code
=
HTTPStatus
.
BAD_REQUEST
.
value
,
)
except
Exception
as
err
:
# pragma: no cover - defensive
logger
.
exception
(
"Failed to pause generation"
)
return
JSONResponse
(
content
=
{
"error"
:
f
"Failed to pause generation:
{
err
}
"
},
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
,
)
@
router
.
post
(
"/resume"
)
async
def
resume_generation
(
raw_request
:
Request
)
->
JSONResponse
:
"""Resume generation after a pause."""
engine
=
engine_client
(
raw_request
)
try
:
await
engine
.
resume_generation
()
return
JSONResponse
(
content
=
{
"status"
:
"resumed"
},
status_code
=
HTTPStatus
.
OK
.
value
,
)
except
Exception
as
err
:
# pragma: no cover - defensive
logger
.
exception
(
"Failed to resume generation"
)
return
JSONResponse
(
content
=
{
"error"
:
f
"Failed to resume generation:
{
err
}
"
},
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
,
)
@
router
.
get
(
"/is_paused"
)
async
def
is_paused
(
raw_request
:
Request
)
->
JSONResponse
:
"""Return the current pause status."""
engine
=
engine_client
(
raw_request
)
try
:
paused
=
await
engine
.
is_paused
()
except
Exception
as
err
:
# pragma: no cover - defensive
logger
.
exception
(
"Failed to fetch pause status"
)
return
JSONResponse
(
content
=
{
"error"
:
f
"Failed to fetch pause status:
{
err
}
"
},
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
,
)
return
JSONResponse
(
content
=
{
"is_paused"
:
paused
})
@
router
.
post
(
"/tokenize"
,
dependencies
=
[
Depends
(
validate_json_request
)],
responses
=
{
HTTPStatus
.
BAD_REQUEST
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
NOT_FOUND
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
NOT_IMPLEMENTED
.
value
:
{
"model"
:
ErrorResponse
},
},
)
@
with_cancellation
async
def
tokenize
(
request
:
TokenizeRequest
,
raw_request
:
Request
):
handler
=
tokenization
(
raw_request
)
try
:
generator
=
await
handler
.
create_tokenize
(
request
,
raw_request
)
except
NotImplementedError
as
e
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
NOT_IMPLEMENTED
.
value
,
detail
=
str
(
e
)
)
from
e
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
,
detail
=
str
(
e
)
)
from
e
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
error
.
code
)
elif
isinstance
(
generator
,
TokenizeResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
())
assert_never
(
generator
)
@
router
.
post
(
"/detokenize"
,
dependencies
=
[
Depends
(
validate_json_request
)],
responses
=
{
HTTPStatus
.
BAD_REQUEST
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
NOT_FOUND
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
:
{
"model"
:
ErrorResponse
},
},
)
@
with_cancellation
async
def
detokenize
(
request
:
DetokenizeRequest
,
raw_request
:
Request
):
handler
=
tokenization
(
raw_request
)
try
:
generator
=
await
handler
.
create_detokenize
(
request
,
raw_request
)
except
OverflowError
as
e
:
raise
RequestValidationError
(
errors
=
[
str
(
e
)])
from
e
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
,
detail
=
str
(
e
)
)
from
e
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
error
.
code
)
elif
isinstance
(
generator
,
DetokenizeResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
())
assert_never
(
generator
)
def
maybe_register_tokenizer_info_endpoint
(
args
):
"""Conditionally register the tokenizer info endpoint if enabled."""
if
getattr
(
args
,
"enable_tokenizer_info_endpoint"
,
False
):
@
router
.
get
(
"/tokenizer_info"
)
async
def
get_tokenizer_info
(
raw_request
:
Request
):
"""Get comprehensive tokenizer information."""
result
=
await
tokenization
(
raw_request
).
get_tokenizer_info
()
return
JSONResponse
(
content
=
result
.
model_dump
(),
status_code
=
result
.
error
.
code
if
isinstance
(
result
,
ErrorResponse
)
else
200
,
)
@
router
.
get
(
"/v1/models"
)
async
def
show_available_models
(
raw_request
:
Request
):
handler
=
models
(
raw_request
)
...
...
@@ -898,33 +683,6 @@ if envs.VLLM_SERVER_DEV_MODE:
await
engine_client
(
raw_request
).
reset_mm_cache
()
return
Response
(
status_code
=
200
)
@
router
.
post
(
"/sleep"
)
async
def
sleep
(
raw_request
:
Request
):
# get POST params
level
=
raw_request
.
query_params
.
get
(
"level"
,
"1"
)
await
engine_client
(
raw_request
).
sleep
(
int
(
level
))
# FIXME: in v0 with frontend multiprocessing, the sleep command
# is sent but does not finish yet when we return a response.
return
Response
(
status_code
=
200
)
@
router
.
post
(
"/wake_up"
)
async
def
wake_up
(
raw_request
:
Request
):
tags
=
raw_request
.
query_params
.
getlist
(
"tags"
)
if
tags
==
[]:
# set to None to wake up all tags if no tags are provided
tags
=
None
logger
.
info
(
"wake up the engine with tags: %s"
,
tags
)
await
engine_client
(
raw_request
).
wake_up
(
tags
)
# FIXME: in v0 with frontend multiprocessing, the wake-up command
# is sent but does not finish yet when we return a response.
return
Response
(
status_code
=
200
)
@
router
.
get
(
"/is_sleeping"
)
async
def
is_sleeping
(
raw_request
:
Request
):
logger
.
info
(
"check whether the engine is sleeping"
)
is_sleeping
=
await
engine_client
(
raw_request
).
is_sleeping
()
return
JSONResponse
(
content
=
{
"is_sleeping"
:
is_sleeping
})
@
router
.
post
(
"/collective_rpc"
)
async
def
collective_rpc
(
raw_request
:
Request
):
try
:
...
...
@@ -952,138 +710,13 @@ if envs.VLLM_SERVER_DEV_MODE:
return
Response
(
status_code
=
200
)
response
:
list
[
Any
]
=
[]
for
result
in
results
:
if
result
is
None
or
isinstance
(
result
,
(
dict
,
list
)
)
:
if
result
is
None
or
isinstance
(
result
,
dict
|
list
):
response
.
append
(
result
)
else
:
response
.
append
(
str
(
result
))
return
JSONResponse
(
content
=
{
"results"
:
response
})
@
router
.
post
(
"/scale_elastic_ep"
,
dependencies
=
[
Depends
(
validate_json_request
)],
responses
=
{
HTTPStatus
.
OK
.
value
:
{
"model"
:
dict
},
HTTPStatus
.
BAD_REQUEST
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
REQUEST_TIMEOUT
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
:
{
"model"
:
ErrorResponse
},
},
)
async
def
scale_elastic_ep
(
raw_request
:
Request
):
try
:
body
=
await
raw_request
.
json
()
except
json
.
JSONDecodeError
as
e
:
raise
HTTPException
(
status_code
=
400
,
detail
=
"Invalid JSON format"
)
from
e
# noqa: B904
new_data_parallel_size
=
body
.
get
(
"new_data_parallel_size"
)
drain_timeout
=
body
.
get
(
"drain_timeout"
,
120
)
# Default 2 minutes
if
new_data_parallel_size
is
None
:
raise
HTTPException
(
status_code
=
400
,
detail
=
"new_data_parallel_size is required"
)
if
not
isinstance
(
new_data_parallel_size
,
int
)
or
new_data_parallel_size
<=
0
:
raise
HTTPException
(
status_code
=
400
,
detail
=
"new_data_parallel_size must be a positive integer"
)
if
not
isinstance
(
drain_timeout
,
int
)
or
drain_timeout
<=
0
:
raise
HTTPException
(
status_code
=
400
,
detail
=
"drain_timeout must be a positive integer"
)
# Set scaling flag to prevent new requests
global
_scaling_elastic_ep
_scaling_elastic_ep
=
True
client
=
engine_client
(
raw_request
)
try
:
await
client
.
scale_elastic_ep
(
new_data_parallel_size
,
drain_timeout
)
return
JSONResponse
(
{
"message"
:
f
"Scaled to
{
new_data_parallel_size
}
data parallel engines"
,
}
)
except
TimeoutError
as
e
:
raise
HTTPException
(
status_code
=
408
,
detail
=
"Scale failed due to request drain timeout "
f
"after
{
drain_timeout
}
seconds"
,
)
from
e
except
Exception
as
e
:
logger
.
error
(
"Scale failed: %s"
,
e
)
raise
HTTPException
(
status_code
=
500
,
detail
=
"Scale failed"
)
from
e
finally
:
_scaling_elastic_ep
=
False
@
router
.
post
(
"/is_scaling_elastic_ep"
)
async
def
is_scaling_elastic_ep
(
raw_request
:
Request
):
return
JSONResponse
({
"is_scaling_elastic_ep"
:
_scaling_elastic_ep
})
@
router
.
post
(
"/inference/v1/generate"
,
dependencies
=
[
Depends
(
validate_json_request
)],
responses
=
{
HTTPStatus
.
OK
.
value
:
{
"content"
:
{
"text/event-stream"
:
{}}},
HTTPStatus
.
BAD_REQUEST
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
NOT_FOUND
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
:
{
"model"
:
ErrorResponse
},
},
)
@
with_cancellation
@
load_aware_call
async
def
generate
(
request
:
GenerateRequest
,
raw_request
:
Request
):
handler
=
generate_tokens
(
raw_request
)
if
handler
is
None
:
return
base
(
raw_request
).
create_error_response
(
message
=
"The model does not support generate tokens API"
)
try
:
generator
=
await
handler
.
serve_tokens
(
request
,
raw_request
)
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
,
detail
=
str
(
e
)
)
from
e
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
error
.
code
)
elif
isinstance
(
generator
,
GenerateResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
())
return
StreamingResponse
(
content
=
generator
,
media_type
=
"text/event-stream"
)
if
envs
.
VLLM_TORCH_PROFILER_DIR
:
logger
.
warning_once
(
"Torch Profiler is enabled in the API server. This should ONLY be "
"used for local development!"
)
elif
envs
.
VLLM_TORCH_CUDA_PROFILE
:
logger
.
warning_once
(
"CUDA Profiler is enabled in the API server. This should ONLY be "
"used for local development!"
)
if
envs
.
VLLM_TORCH_PROFILER_DIR
or
envs
.
VLLM_TORCH_CUDA_PROFILE
:
@
router
.
post
(
"/start_profile"
)
async
def
start_profile
(
raw_request
:
Request
):
logger
.
info
(
"Starting profiler..."
)
await
engine_client
(
raw_request
).
start_profile
()
logger
.
info
(
"Profiler started."
)
return
Response
(
status_code
=
200
)
@
router
.
post
(
"/stop_profile"
)
async
def
stop_profile
(
raw_request
:
Request
):
logger
.
info
(
"Stopping profiler..."
)
await
engine_client
(
raw_request
).
stop_profile
()
logger
.
info
(
"Profiler stopped."
)
return
Response
(
status_code
=
200
)
def
load_log_config
(
log_config_file
:
str
|
None
)
->
dict
|
None
:
if
not
log_config_file
:
return
None
...
...
@@ -1176,41 +809,6 @@ class XRequestIdMiddleware:
return
self
.
app
(
scope
,
receive
,
send_with_request_id
)
# Global variable to track scaling state
_scaling_elastic_ep
=
False
class
ScalingMiddleware
:
"""
Middleware that checks if the model is currently scaling and
returns a 503 Service Unavailable response if it is.
This middleware applies to all HTTP requests and prevents
processing when the model is in a scaling state.
"""
def
__init__
(
self
,
app
:
ASGIApp
)
->
None
:
self
.
app
=
app
def
__call__
(
self
,
scope
:
Scope
,
receive
:
Receive
,
send
:
Send
)
->
Awaitable
[
None
]:
if
scope
[
"type"
]
!=
"http"
:
return
self
.
app
(
scope
,
receive
,
send
)
# Check global scaling state
global
_scaling_elastic_ep
if
_scaling_elastic_ep
:
# Return 503 Service Unavailable response
response
=
JSONResponse
(
content
=
{
"error"
:
"The model is currently scaling. Please try again later."
},
status_code
=
503
,
)
return
response
(
scope
,
receive
,
send
)
return
self
.
app
(
scope
,
receive
,
send
)
def
_extract_content_from_chunk
(
chunk_data
:
dict
)
->
str
:
"""Extract content from a streaming response chunk."""
try
:
...
...
@@ -1353,15 +951,10 @@ def build_app(args: Namespace) -> FastAPI:
)
else
:
app
=
FastAPI
(
lifespan
=
lifespan
)
app
.
state
.
args
=
args
from
vllm.entrypoints.serve
import
register_vllm_serve_api_routers
if
envs
.
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
logger
.
warning
(
"LoRA dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!"
)
from
vllm.entrypoints.dynamic_lora
import
register_dynamic_lora_routes
register_dynamic_lora_routes
(
router
)
register_vllm_serve_api_routers
(
app
)
from
vllm.entrypoints.sagemaker.routes
import
register_sagemaker_routes
...
...
@@ -1370,8 +963,6 @@ def build_app(args: Namespace) -> FastAPI:
app
.
root_path
=
args
.
root_path
mount_metrics
(
app
)
from
vllm.entrypoints.pooling
import
register_pooling_api_routers
register_pooling_api_routers
(
app
)
...
...
@@ -1462,31 +1053,6 @@ def build_app(args: Namespace) -> FastAPI:
)
app
=
sagemaker_standards
.
bootstrap
(
app
)
# Optional endpoints
if
args
.
tokens_only
:
@
app
.
post
(
"/abort_requests"
)
async
def
abort_requests
(
raw_request
:
Request
):
"""
Abort one or more requests. To be used in a
Disaggregated Everything setup.
"""
try
:
body
=
await
raw_request
.
json
()
except
json
.
JSONDecodeError
as
e
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
BAD_REQUEST
.
value
,
detail
=
f
"JSON decode error:
{
e
}
"
,
)
from
e
request_ids
=
body
.
get
(
"request_ids"
)
if
request_ids
is
None
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
BAD_REQUEST
.
value
,
detail
=
"Missing 'request_ids' in request body"
,
)
# Abort requests in background
asyncio
.
create_task
(
engine_client
(
raw_request
).
abort
(
request_ids
))
return
Response
(
status_code
=
200
)
return
app
...
...
@@ -1515,7 +1081,7 @@ async def init_app_state(
state
.
engine_client
=
engine_client
state
.
log_stats
=
not
args
.
disable_log_stats
state
.
vllm_config
=
vllm_config
state
.
args
=
args
supported_tasks
=
await
engine_client
.
get_supported_tasks
()
logger
.
info
(
"Supported tasks: %s"
,
supported_tasks
)
...
...
@@ -1839,7 +1405,6 @@ async def run_server_worker(
args
,
client_config
=
client_config
,
)
as
engine_client
:
maybe_register_tokenizer_info_endpoint
(
args
)
app
=
build_app
(
args
)
await
init_app_state
(
engine_client
,
app
.
state
,
args
)
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
3f42b05f
...
...
@@ -74,8 +74,6 @@ from vllm.entrypoints.openai.protocol import (
ErrorResponse
,
FunctionCall
,
FunctionDefinition
,
GenerateRequest
,
GenerateResponse
,
ResponsesRequest
,
TokenizeChatRequest
,
TokenizeCompletionRequest
,
...
...
@@ -87,6 +85,7 @@ from vllm.entrypoints.openai.protocol import (
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.entrypoints.openai.tool_parsers
import
ToolParser
,
ToolParserManager
from
vllm.entrypoints.renderer
import
BaseRenderer
,
CompletionRenderer
,
RenderConfig
from
vllm.entrypoints.serve.disagg.protocol
import
GenerateRequest
,
GenerateResponse
from
vllm.entrypoints.utils
import
_validate_truncation_size
from
vllm.inputs.data
import
PromptType
from
vllm.inputs.data
import
TokensPrompt
as
EngineTokensPrompt
...
...
vllm/entrypoints/sagemaker/routes.py
View file @
3f42b05f
...
...
@@ -16,7 +16,6 @@ from vllm.entrypoints.openai.api_server import (
completion
,
create_chat_completion
,
create_completion
,
health
,
validate_json_request
,
)
from
vllm.entrypoints.openai.protocol
import
(
...
...
@@ -38,6 +37,7 @@ from vllm.entrypoints.pooling.score.api_router import (
score
,
)
from
vllm.entrypoints.pooling.score.protocol
import
RerankRequest
,
ScoreRequest
from
vllm.entrypoints.serve.instrumentator.health
import
health
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
...
...
vllm/entrypoints/serve/__init__.py
0 → 100644
View file @
3f42b05f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
fastapi
import
FastAPI
def
register_vllm_serve_api_routers
(
app
:
FastAPI
):
from
vllm.entrypoints.serve.lora.api_router
import
(
attach_router
as
attach_lora_router
,
)
attach_lora_router
(
app
)
from
vllm.entrypoints.serve.elastic_ep.api_router
import
(
attach_router
as
attach_elastic_ep_router
,
)
attach_elastic_ep_router
(
app
)
from
vllm.entrypoints.serve.profile.api_router
import
(
attach_router
as
attach_profile_router
,
)
attach_profile_router
(
app
)
from
vllm.entrypoints.serve.sleep.api_router
import
(
attach_router
as
attach_sleep_router
,
)
attach_sleep_router
(
app
)
from
vllm.entrypoints.serve.tokenize.api_router
import
(
attach_router
as
attach_tokenize_router
,
)
attach_tokenize_router
(
app
)
from
vllm.entrypoints.serve.disagg.api_router
import
(
attach_router
as
attach_disagg_router
,
)
attach_disagg_router
(
app
)
from
vllm.entrypoints.serve.rlhf.api_router
import
(
attach_router
as
attach_rlhf_router
,
)
attach_rlhf_router
(
app
)
from
vllm.entrypoints.serve.instrumentator.metrics
import
(
attach_router
as
attach_metrics_router
,
)
attach_metrics_router
(
app
)
from
vllm.entrypoints.serve.instrumentator.health
import
(
attach_router
as
attach_health_router
,
)
attach_health_router
(
app
)
vllm/entrypoints/serve/disagg/__init__.py
0 → 100644
View file @
3f42b05f
vllm/entrypoints/serve/disagg/api_router.py
0 → 100644
View file @
3f42b05f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
json
from
http
import
HTTPStatus
from
fastapi
import
APIRouter
,
Depends
,
FastAPI
,
HTTPException
,
Request
,
Response
from
fastapi.responses
import
JSONResponse
,
StreamingResponse
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.openai.api_server
import
validate_json_request
from
vllm.entrypoints.openai.protocol
import
(
ErrorResponse
,
)
from
vllm.entrypoints.serve.disagg.protocol
import
(
GenerateRequest
,
GenerateResponse
,
)
from
vllm.entrypoints.serve.disagg.serving
import
(
ServingTokens
,
)
from
vllm.entrypoints.serve.tokenize.serving
import
OpenAIServingTokenization
from
vllm.entrypoints.utils
import
(
load_aware_call
,
with_cancellation
,
)
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def
tokenization
(
request
:
Request
)
->
OpenAIServingTokenization
:
return
request
.
app
.
state
.
openai_serving_tokenization
def
generate_tokens
(
request
:
Request
)
->
ServingTokens
|
None
:
return
request
.
app
.
state
.
serving_tokens
def
engine_client
(
request
:
Request
)
->
EngineClient
:
return
request
.
app
.
state
.
engine_client
router
=
APIRouter
()
@
router
.
post
(
"/inference/v1/generate"
,
dependencies
=
[
Depends
(
validate_json_request
)],
responses
=
{
HTTPStatus
.
OK
.
value
:
{
"content"
:
{
"text/event-stream"
:
{}}},
HTTPStatus
.
BAD_REQUEST
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
NOT_FOUND
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
:
{
"model"
:
ErrorResponse
},
},
)
@
with_cancellation
@
load_aware_call
async
def
generate
(
request
:
GenerateRequest
,
raw_request
:
Request
):
handler
=
generate_tokens
(
raw_request
)
if
handler
is
None
:
return
tokenization
(
raw_request
).
create_error_response
(
message
=
"The model does not support generate tokens API"
)
try
:
generator
=
await
handler
.
serve_tokens
(
request
,
raw_request
)
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
,
detail
=
str
(
e
)
)
from
e
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
error
.
code
)
elif
isinstance
(
generator
,
GenerateResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
())
return
StreamingResponse
(
content
=
generator
,
media_type
=
"text/event-stream"
)
def
attach_router
(
app
:
FastAPI
):
if
getattr
(
app
.
state
.
args
,
"tokens_only"
,
False
):
@
router
.
post
(
"/abort_requests"
)
async
def
abort_requests
(
raw_request
:
Request
):
"""
Abort one or more requests. To be used in a
Disaggregated Everything setup.
"""
try
:
body
=
await
raw_request
.
json
()
except
json
.
JSONDecodeError
as
e
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
BAD_REQUEST
.
value
,
detail
=
f
"JSON decode error:
{
e
}
"
,
)
from
e
request_ids
=
body
.
get
(
"request_ids"
)
if
request_ids
is
None
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
BAD_REQUEST
.
value
,
detail
=
"Missing 'request_ids' in request body"
,
)
# Abort requests in background
asyncio
.
create_task
(
engine_client
(
raw_request
).
abort
(
request_ids
))
return
Response
(
status_code
=
200
)
app
.
include_router
(
router
)
vllm/entrypoints/serve/disagg/protocol.py
0 → 100644
View file @
3f42b05f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
from
pydantic
import
BaseModel
,
Field
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionLogProbs
,
Logprob
,
SamplingParams
,
StreamOptions
,
)
from
vllm.utils
import
random_uuid
####### Tokens IN <> Tokens OUT #######
class
GenerateRequest
(
BaseModel
):
request_id
:
str
=
Field
(
default_factory
=
lambda
:
f
"
{
random_uuid
()
}
"
,
description
=
(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
token_ids
:
list
[
int
]
"""The token ids to generate text from."""
# features: MultiModalFeatureSpec
# TODO (NickLucche): implement once Renderer work is completed
features
:
str
|
None
=
None
"""The processed MM inputs for the model."""
sampling_params
:
SamplingParams
"""The sampling parameters for the model."""
model
:
str
|
None
=
None
stream
:
bool
|
None
=
False
stream_options
:
StreamOptions
|
None
=
None
cache_salt
:
str
|
None
=
Field
(
default
=
None
,
description
=
(
"If specified, the prefix cache will be salted with the provided "
"string to prevent an attacker to guess prompts in multi-user "
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit)."
),
)
priority
:
int
=
Field
(
default
=
0
,
description
=
(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
kv_transfer_params
:
dict
[
str
,
Any
]
|
None
=
Field
(
default
=
None
,
description
=
"KVTransfer parameters used for disaggregated serving."
,
)
class
GenerateResponseChoice
(
BaseModel
):
index
:
int
logprobs
:
ChatCompletionLogProbs
|
None
=
None
# per OpenAI spec this is the default
finish_reason
:
str
|
None
=
"stop"
token_ids
:
list
[
int
]
|
None
=
None
class
GenerateResponse
(
BaseModel
):
request_id
:
str
=
Field
(
default_factory
=
lambda
:
f
"
{
random_uuid
()
}
"
,
description
=
(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
choices
:
list
[
GenerateResponseChoice
]
prompt_logprobs
:
list
[
dict
[
int
,
Logprob
]
|
None
]
|
None
=
None
kv_transfer_params
:
dict
[
str
,
Any
]
|
None
=
Field
(
default
=
None
,
description
=
"KVTransfer parameters used for disaggregated serving."
,
)
vllm/entrypoints/
openai
/serving
_tokens
.py
→
vllm/entrypoints/
serve/disagg
/serving.py
View file @
3f42b05f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
time
from
collections.abc
import
AsyncGenerator
...
...
@@ -14,15 +16,17 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProbs
,
ChatCompletionLogProbsContent
,
ErrorResponse
,
GenerateRequest
,
GenerateResponse
,
GenerateResponseChoice
,
PromptTokenUsageInfo
,
RequestResponseMetadata
,
UsageInfo
,
)
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
,
clamp_prompt_logprobs
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.entrypoints.serve.disagg.protocol
import
(
GenerateRequest
,
GenerateResponse
,
GenerateResponseChoice
,
)
from
vllm.inputs.data
import
TokensPrompt
as
EngineTokensPrompt
from
vllm.logger
import
init_logger
from
vllm.logprobs
import
Logprob
...
...
vllm/entrypoints/serve/elastic_ep/__init__.py
0 → 100644
View file @
3f42b05f
vllm/entrypoints/serve/elastic_ep/api_router.py
0 → 100644
View file @
3f42b05f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
from
http
import
HTTPStatus
from
fastapi
import
APIRouter
,
Depends
,
FastAPI
,
HTTPException
,
Request
from
fastapi.responses
import
JSONResponse
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.openai.api_server
import
validate_json_request
from
vllm.entrypoints.openai.protocol
import
(
ErrorResponse
,
)
from
vllm.entrypoints.serve.elastic_ep.middleware
import
(
get_scaling_elastic_ep
,
set_scaling_elastic_ep
,
)
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def
engine_client
(
request
:
Request
)
->
EngineClient
:
return
request
.
app
.
state
.
engine_client
router
=
APIRouter
()
@
router
.
post
(
"/scale_elastic_ep"
,
dependencies
=
[
Depends
(
validate_json_request
)],
responses
=
{
HTTPStatus
.
OK
.
value
:
{
"model"
:
dict
},
HTTPStatus
.
BAD_REQUEST
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
REQUEST_TIMEOUT
.
value
:
{
"model"
:
ErrorResponse
},
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
:
{
"model"
:
ErrorResponse
},
},
)
async
def
scale_elastic_ep
(
raw_request
:
Request
):
try
:
body
=
await
raw_request
.
json
()
except
json
.
JSONDecodeError
as
e
:
raise
HTTPException
(
status_code
=
400
,
detail
=
"Invalid JSON format"
)
from
e
# noqa: B904
new_data_parallel_size
=
body
.
get
(
"new_data_parallel_size"
)
drain_timeout
=
body
.
get
(
"drain_timeout"
,
120
)
# Default 2 minutes
if
new_data_parallel_size
is
None
:
raise
HTTPException
(
status_code
=
400
,
detail
=
"new_data_parallel_size is required"
)
if
not
isinstance
(
new_data_parallel_size
,
int
)
or
new_data_parallel_size
<=
0
:
raise
HTTPException
(
status_code
=
400
,
detail
=
"new_data_parallel_size must be a positive integer"
,
)
if
not
isinstance
(
drain_timeout
,
int
)
or
drain_timeout
<=
0
:
raise
HTTPException
(
status_code
=
400
,
detail
=
"drain_timeout must be a positive integer"
)
# Set scaling flag to prevent new requests
set_scaling_elastic_ep
(
True
)
client
=
engine_client
(
raw_request
)
try
:
await
client
.
scale_elastic_ep
(
new_data_parallel_size
,
drain_timeout
)
return
JSONResponse
(
{
"message"
:
f
"Scaled to
{
new_data_parallel_size
}
data parallel engines"
,
}
)
except
TimeoutError
as
e
:
raise
HTTPException
(
status_code
=
408
,
detail
=
"Scale failed due to request drain timeout "
f
"after
{
drain_timeout
}
seconds"
,
)
from
e
except
Exception
as
e
:
logger
.
error
(
"Scale failed: %s"
,
e
)
raise
HTTPException
(
status_code
=
500
,
detail
=
"Scale failed"
)
from
e
finally
:
set_scaling_elastic_ep
(
False
)
@
router
.
post
(
"/is_scaling_elastic_ep"
)
async
def
is_scaling_elastic_ep
(
raw_request
:
Request
):
return
JSONResponse
({
"is_scaling_elastic_ep"
:
get_scaling_elastic_ep
()})
def
attach_router
(
app
:
FastAPI
):
app
.
include_router
(
router
)
vllm/entrypoints/serve/elastic_ep/middleware.py
0 → 100644
View file @
3f42b05f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Awaitable
from
fastapi.responses
import
JSONResponse
from
starlette.types
import
ASGIApp
,
Receive
,
Scope
,
Send
# Global variable to track scaling state
_scaling_elastic_ep
=
False
def
get_scaling_elastic_ep
():
return
_scaling_elastic_ep
def
set_scaling_elastic_ep
(
value
):
global
_scaling_elastic_ep
_scaling_elastic_ep
=
value
class
ScalingMiddleware
:
"""
Middleware that checks if the model is currently scaling and
returns a 503 Service Unavailable response if it is.
This middleware applies to all HTTP requests and prevents
processing when the model is in a scaling state.
"""
def
__init__
(
self
,
app
:
ASGIApp
)
->
None
:
self
.
app
=
app
def
__call__
(
self
,
scope
:
Scope
,
receive
:
Receive
,
send
:
Send
)
->
Awaitable
[
None
]:
if
scope
[
"type"
]
!=
"http"
:
return
self
.
app
(
scope
,
receive
,
send
)
# Check global scaling state
if
get_scaling_elastic_ep
():
# Return 503 Service Unavailable response
response
=
JSONResponse
(
content
=
{
"error"
:
"The model is currently scaling. Please try again later."
},
status_code
=
503
,
)
return
response
(
scope
,
receive
,
send
)
return
self
.
app
(
scope
,
receive
,
send
)
vllm/entrypoints/serve/instrumentator/__init__.py
0 → 100644
View file @
3f42b05f
vllm/entrypoints/serve/instrumentator/health.py
0 → 100644
View file @
3f42b05f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
fastapi
import
APIRouter
,
Request
from
fastapi.responses
import
Response
from
vllm.engine.protocol
import
EngineClient
from
vllm.logger
import
init_logger
from
vllm.v1.engine.exceptions
import
EngineDeadError
logger
=
init_logger
(
__name__
)
router
=
APIRouter
()
def
engine_client
(
request
:
Request
)
->
EngineClient
:
return
request
.
app
.
state
.
engine_client
@
router
.
get
(
"/health"
,
response_class
=
Response
)
async
def
health
(
raw_request
:
Request
)
->
Response
:
"""Health check."""
try
:
await
engine_client
(
raw_request
).
check_health
()
return
Response
(
status_code
=
200
)
except
EngineDeadError
:
return
Response
(
status_code
=
503
)
def
attach_router
(
app
):
app
.
include_router
(
router
)
vllm/entrypoints/serve/instrumentator/metrics.py
0 → 100644
View file @
3f42b05f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
re
import
prometheus_client
from
fastapi
import
FastAPI
,
Response
from
prometheus_client
import
make_asgi_app
from
prometheus_fastapi_instrumentator
import
Instrumentator
from
starlette.routing
import
Mount
from
vllm.v1.metrics.prometheus
import
get_prometheus_registry
class
PrometheusResponse
(
Response
):
media_type
=
prometheus_client
.
CONTENT_TYPE_LATEST
def
attach_router
(
app
:
FastAPI
):
"""Mount prometheus metrics to a FastAPI app."""
registry
=
get_prometheus_registry
()
# `response_class=PrometheusResponse` is needed to return an HTTP response
# with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
# instead of the default "application/json" which is incorrect.
# See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364
Instrumentator
(
excluded_handlers
=
[
"/metrics"
,
"/health"
,
"/load"
,
"/ping"
,
"/version"
,
"/server_info"
,
],
registry
=
registry
,
).
add
().
instrument
(
app
).
expose
(
app
,
response_class
=
PrometheusResponse
)
# Add prometheus asgi middleware to route /metrics requests
metrics_route
=
Mount
(
"/metrics"
,
make_asgi_app
(
registry
=
registry
))
# Workaround for 307 Redirect for /metrics
metrics_route
.
path_regex
=
re
.
compile
(
"^/metrics(?P<path>.*)$"
)
app
.
routes
.
append
(
metrics_route
)
vllm/entrypoints/serve/lora/__init__.py
0 → 100644
View file @
3f42b05f
vllm/entrypoints/
dynamic_lora
.py
→
vllm/entrypoints/
serve/lora/api_router
.py
View file @
3f42b05f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
model_hosting_container_standards.sagemaker
as
sagemaker_standards
from
fastapi
import
APIRouter
,
Depends
,
Request
from
fastapi
import
APIRouter
,
Depends
,
FastAPI
,
Request
from
fastapi.responses
import
JSONResponse
,
Response
from
vllm
import
envs
from
vllm.entrypoints.openai.api_server
import
models
,
validate_json_request
from
vllm.entrypoints.openai.protocol
import
(
ErrorResponse
,
...
...
@@ -14,9 +17,18 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
router
=
APIRouter
()
def
register_dynamic_lora_routes
(
router
:
APIRouter
):
def
attach_router
(
app
:
FastAPI
):
if
not
envs
.
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
"""If LoRA dynamic loading & unloading is not enabled, do nothing."""
return
logger
.
warning
(
"LoRA dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!"
)
@
sagemaker_standards
.
register_load_adapter_handler
(
request_shape
=
{
"lora_name"
:
"body.name"
,
...
...
@@ -54,4 +66,5 @@ def register_dynamic_lora_routes(router: APIRouter):
return
Response
(
status_code
=
200
,
content
=
response
)
return
router
# register the router
app
.
include_router
(
router
)
vllm/entrypoints/serve/profile/__init__.py
0 → 100644
View file @
3f42b05f
vllm/entrypoints/serve/profile/api_router.py
0 → 100644
View file @
3f42b05f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
fastapi
import
APIRouter
,
FastAPI
,
Request
from
fastapi.responses
import
Response
import
vllm.envs
as
envs
from
vllm.engine.protocol
import
EngineClient
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
router
=
APIRouter
()
def
engine_client
(
request
:
Request
)
->
EngineClient
:
return
request
.
app
.
state
.
engine_client
@
router
.
post
(
"/start_profile"
)
async
def
start_profile
(
raw_request
:
Request
):
logger
.
info
(
"Starting profiler..."
)
await
engine_client
(
raw_request
).
start_profile
()
logger
.
info
(
"Profiler started."
)
return
Response
(
status_code
=
200
)
@
router
.
post
(
"/stop_profile"
)
async
def
stop_profile
(
raw_request
:
Request
):
logger
.
info
(
"Stopping profiler..."
)
await
engine_client
(
raw_request
).
stop_profile
()
logger
.
info
(
"Profiler stopped."
)
return
Response
(
status_code
=
200
)
def
attach_router
(
app
:
FastAPI
):
if
envs
.
VLLM_TORCH_PROFILER_DIR
:
logger
.
warning_once
(
"Torch Profiler is enabled in the API server. This should ONLY be "
"used for local development!"
)
elif
envs
.
VLLM_TORCH_CUDA_PROFILE
:
logger
.
warning_once
(
"CUDA Profiler is enabled in the API server. This should ONLY be "
"used for local development!"
)
if
envs
.
VLLM_TORCH_PROFILER_DIR
or
envs
.
VLLM_TORCH_CUDA_PROFILE
:
app
.
include_router
(
router
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment