Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cc08fc72
Unverified
Commit
cc08fc72
authored
Aug 05, 2024
by
Cyrus Leung
Committed by
GitHub
Aug 04, 2024
Browse files
[Frontend] Reapply "Factor out code for running uvicorn" (#7095)
parent
7b86e7c9
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
125 additions
and
82 deletions
+125
-82
vllm/entrypoints/api_server.py
vllm/entrypoints/api_server.py
+53
-24
vllm/entrypoints/launcher.py
vllm/entrypoints/launcher.py
+46
-0
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+26
-58
No files found.
vllm/entrypoints/api_server.py
View file @
cc08fc72
...
@@ -5,21 +5,23 @@ For production use, we recommend using our OpenAI compatible server.
...
@@ -5,21 +5,23 @@ For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
change `vllm/entrypoints/openai/api_server.py` instead.
"""
"""
import
asyncio
import
json
import
json
import
ssl
import
ssl
from
typing
import
AsyncGenerator
from
argparse
import
Namespace
from
typing
import
Any
,
AsyncGenerator
,
Optional
import
uvicorn
from
fastapi
import
FastAPI
,
Request
from
fastapi
import
FastAPI
,
Request
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.launcher
import
serve_http
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
,
random_uuid
from
vllm.utils
import
FlexibleArgumentParser
,
random_uuid
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
"vllm.entrypoints.api_server"
)
logger
=
init_logger
(
"vllm.entrypoints.api_server"
)
...
@@ -81,6 +83,53 @@ async def generate(request: Request) -> Response:
...
@@ -81,6 +83,53 @@ async def generate(request: Request) -> Response:
return
JSONResponse
(
ret
)
return
JSONResponse
(
ret
)
def
build_app
(
args
:
Namespace
)
->
FastAPI
:
global
app
app
.
root_path
=
args
.
root_path
return
app
async
def
init_app
(
args
:
Namespace
,
llm_engine
:
Optional
[
AsyncLLMEngine
]
=
None
,
)
->
FastAPI
:
app
=
build_app
(
args
)
global
engine
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine
=
(
llm_engine
if
llm_engine
is
not
None
else
AsyncLLMEngine
.
from_engine_args
(
engine_args
,
usage_context
=
UsageContext
.
API_SERVER
))
return
app
async
def
run_server
(
args
:
Namespace
,
llm_engine
:
Optional
[
AsyncLLMEngine
]
=
None
,
**
uvicorn_kwargs
:
Any
)
->
None
:
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
logger
.
info
(
"args: %s"
,
args
)
app
=
await
init_app
(
args
,
llm_engine
)
shutdown_task
=
await
serve_http
(
app
,
host
=
args
.
host
,
port
=
args
.
port
,
log_level
=
args
.
log_level
,
timeout_keep_alive
=
TIMEOUT_KEEP_ALIVE
,
ssl_keyfile
=
args
.
ssl_keyfile
,
ssl_certfile
=
args
.
ssl_certfile
,
ssl_ca_certs
=
args
.
ssl_ca_certs
,
ssl_cert_reqs
=
args
.
ssl_cert_reqs
,
**
uvicorn_kwargs
,
)
await
shutdown_task
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
FlexibleArgumentParser
()
parser
=
FlexibleArgumentParser
()
parser
.
add_argument
(
"--host"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--host"
,
type
=
str
,
default
=
None
)
...
@@ -105,25 +154,5 @@ if __name__ == "__main__":
...
@@ -105,25 +154,5 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--log-level"
,
type
=
str
,
default
=
"debug"
)
parser
.
add_argument
(
"--log-level"
,
type
=
str
,
default
=
"debug"
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
,
usage_context
=
UsageContext
.
API_SERVER
)
app
.
root_path
=
args
.
root_path
logger
.
info
(
"Available routes are:"
)
asyncio
.
run
(
run_server
(
args
))
for
route
in
app
.
routes
:
if
not
hasattr
(
route
,
'methods'
):
continue
methods
=
', '
.
join
(
route
.
methods
)
logger
.
info
(
"Route: %s, Methods: %s"
,
route
.
path
,
methods
)
uvicorn
.
run
(
app
,
host
=
args
.
host
,
port
=
args
.
port
,
log_level
=
args
.
log_level
,
timeout_keep_alive
=
TIMEOUT_KEEP_ALIVE
,
ssl_keyfile
=
args
.
ssl_keyfile
,
ssl_certfile
=
args
.
ssl_certfile
,
ssl_ca_certs
=
args
.
ssl_ca_certs
,
ssl_cert_reqs
=
args
.
ssl_cert_reqs
)
vllm/entrypoints/launcher.py
0 → 100644
View file @
cc08fc72
import
asyncio
import
signal
from
typing
import
Any
import
uvicorn
from
fastapi
import
FastAPI
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
async
def
serve_http
(
app
:
FastAPI
,
**
uvicorn_kwargs
:
Any
):
logger
.
info
(
"Available routes are:"
)
for
route
in
app
.
routes
:
methods
=
getattr
(
route
,
"methods"
,
None
)
path
=
getattr
(
route
,
"path"
,
None
)
if
methods
is
None
or
path
is
None
:
continue
logger
.
info
(
"Route: %s, Methods: %s"
,
path
,
', '
.
join
(
methods
))
config
=
uvicorn
.
Config
(
app
,
**
uvicorn_kwargs
)
server
=
uvicorn
.
Server
(
config
)
loop
=
asyncio
.
get_running_loop
()
server_task
=
loop
.
create_task
(
server
.
serve
())
def
signal_handler
()
->
None
:
# prevents the uvicorn signal handler to exit early
server_task
.
cancel
()
async
def
dummy_shutdown
()
->
None
:
pass
loop
.
add_signal_handler
(
signal
.
SIGINT
,
signal_handler
)
loop
.
add_signal_handler
(
signal
.
SIGTERM
,
signal_handler
)
try
:
await
server_task
return
dummy_shutdown
()
except
asyncio
.
CancelledError
:
logger
.
info
(
"Gracefully stopping http server"
)
return
server
.
shutdown
()
vllm/entrypoints/openai/api_server.py
View file @
cc08fc72
...
@@ -2,15 +2,13 @@ import asyncio
...
@@ -2,15 +2,13 @@ import asyncio
import
importlib
import
importlib
import
inspect
import
inspect
import
re
import
re
import
signal
from
argparse
import
Namespace
from
contextlib
import
asynccontextmanager
from
contextlib
import
asynccontextmanager
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
multiprocessing
import
Process
from
multiprocessing
import
Process
from
typing
import
AsyncIterator
,
Set
from
typing
import
AsyncIterator
,
Set
import
fastapi
from
fastapi
import
APIRouter
,
FastAPI
,
Request
import
uvicorn
from
fastapi
import
APIRouter
,
Request
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
...
@@ -22,6 +20,7 @@ from vllm.config import ModelConfig
...
@@ -22,6 +20,7 @@ from vllm.config import ModelConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.entrypoints.launcher
import
serve_http
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
...
@@ -71,7 +70,7 @@ def model_is_embedding(model_name: str) -> bool:
...
@@ -71,7 +70,7 @@ def model_is_embedding(model_name: str) -> bool:
@
asynccontextmanager
@
asynccontextmanager
async
def
lifespan
(
app
:
fastapi
.
FastAPI
):
async
def
lifespan
(
app
:
FastAPI
):
async
def
_force_log
():
async
def
_force_log
():
while
True
:
while
True
:
...
@@ -135,7 +134,7 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
...
@@ -135,7 +134,7 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
router
=
APIRouter
()
router
=
APIRouter
()
def
mount_metrics
(
app
:
fastapi
.
FastAPI
):
def
mount_metrics
(
app
:
FastAPI
):
# Add prometheus asgi middleware to route /metrics requests
# Add prometheus asgi middleware to route /metrics requests
metrics_route
=
Mount
(
"/metrics"
,
make_asgi_app
())
metrics_route
=
Mount
(
"/metrics"
,
make_asgi_app
())
# Workaround for 307 Redirect for /metrics
# Workaround for 307 Redirect for /metrics
...
@@ -225,8 +224,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
...
@@ -225,8 +224,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
return
JSONResponse
(
content
=
generator
.
model_dump
())
return
JSONResponse
(
content
=
generator
.
model_dump
())
def
build_app
(
args
)
:
def
build_app
(
args
:
Namespace
)
->
FastAPI
:
app
=
fastapi
.
FastAPI
(
lifespan
=
lifespan
)
app
=
FastAPI
(
lifespan
=
lifespan
)
app
.
include_router
(
router
)
app
.
include_router
(
router
)
app
.
root_path
=
args
.
root_path
app
.
root_path
=
args
.
root_path
...
@@ -274,11 +273,10 @@ def build_app(args):
...
@@ -274,11 +273,10 @@ def build_app(args):
return
app
return
app
async
def
build_server
(
async
def
init_app
(
async_engine_client
:
AsyncEngineClient
,
async_engine_client
:
AsyncEngineClient
,
args
,
args
:
Namespace
,
**
uvicorn_kwargs
,
)
->
FastAPI
:
)
->
uvicorn
.
Server
:
app
=
build_app
(
args
)
app
=
build_app
(
args
)
if
args
.
served_model_name
is
not
None
:
if
args
.
served_model_name
is
not
None
:
...
@@ -334,14 +332,17 @@ async def build_server(
...
@@ -334,14 +332,17 @@ async def build_server(
)
)
app
.
root_path
=
args
.
root_path
app
.
root_path
=
args
.
root_path
logger
.
info
(
"Available routes are:"
)
return
app
for
route
in
app
.
routes
:
if
not
hasattr
(
route
,
'methods'
):
continue
async
def
run_server
(
args
,
**
uvicorn_kwargs
)
->
None
:
methods
=
', '
.
join
(
route
.
methods
)
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
logger
.
info
(
"Route: %s, Methods: %s"
,
route
.
path
,
methods
)
logger
.
info
(
"args: %s"
,
args
)
async
with
build_async_engine_client
(
args
)
as
async_engine_client
:
app
=
await
init_app
(
async_engine_client
,
args
)
config
=
uvicorn
.
Config
(
shutdown_task
=
await
serve_http
(
app
,
app
,
host
=
args
.
host
,
host
=
args
.
host
,
port
=
args
.
port
,
port
=
args
.
port
,
...
@@ -354,40 +355,6 @@ async def build_server(
...
@@ -354,40 +355,6 @@ async def build_server(
**
uvicorn_kwargs
,
**
uvicorn_kwargs
,
)
)
return
uvicorn
.
Server
(
config
)
async
def
run_server
(
args
,
**
uvicorn_kwargs
)
->
None
:
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
logger
.
info
(
"args: %s"
,
args
)
shutdown_task
=
None
async
with
build_async_engine_client
(
args
)
as
async_engine_client
:
server
=
await
build_server
(
async_engine_client
,
args
,
**
uvicorn_kwargs
,
)
loop
=
asyncio
.
get_running_loop
()
server_task
=
loop
.
create_task
(
server
.
serve
())
def
signal_handler
()
->
None
:
# prevents the uvicorn signal handler to exit early
server_task
.
cancel
()
loop
.
add_signal_handler
(
signal
.
SIGINT
,
signal_handler
)
loop
.
add_signal_handler
(
signal
.
SIGTERM
,
signal_handler
)
try
:
await
server_task
except
asyncio
.
CancelledError
:
logger
.
info
(
"Gracefully stopping http server"
)
shutdown_task
=
server
.
shutdown
()
if
shutdown_task
:
# NB: Await server shutdown only after the backend context is exited
# NB: Await server shutdown only after the backend context is exited
await
shutdown_task
await
shutdown_task
...
@@ -399,4 +366,5 @@ if __name__ == "__main__":
...
@@ -399,4 +366,5 @@ if __name__ == "__main__":
description
=
"vLLM OpenAI-Compatible RESTful API server."
)
description
=
"vLLM OpenAI-Compatible RESTful API server."
)
parser
=
make_arg_parser
(
parser
)
parser
=
make_arg_parser
(
parser
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
asyncio
.
run
(
run_server
(
args
))
asyncio
.
run
(
run_server
(
args
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment