Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8db1b9d0
Unverified
Commit
8db1b9d0
authored
Feb 22, 2025
by
Keyun Tong
Committed by
GitHub
Feb 22, 2025
Browse files
Support SSL Key Rotation in HTTP Server (#13495)
parent
2382ad29
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
173 additions
and
2 deletions
+173
-2
requirements-common.txt
requirements-common.txt
+2
-1
tests/entrypoints/test_ssl_cert_refresher.py
tests/entrypoints/test_ssl_cert_refresher.py
+72
-0
vllm/entrypoints/api_server.py
vllm/entrypoints/api_server.py
+6
-0
vllm/entrypoints/launcher.py
vllm/entrypoints/launcher.py
+13
-1
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+1
-0
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+5
-0
vllm/entrypoints/ssl.py
vllm/entrypoints/ssl.py
+74
-0
No files found.
requirements-common.txt
View file @
8db1b9d0
...
@@ -37,3 +37,4 @@ einops # Required for Qwen2-VL.
...
@@ -37,3 +37,4 @@ einops # Required for Qwen2-VL.
compressed-tensors == 0.9.2 # required for compressed-tensors
compressed-tensors == 0.9.2 # required for compressed-tensors
depyf==0.18.0 # required for profiling and debugging with compilation config
depyf==0.18.0 # required for profiling and debugging with compilation config
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
watchfiles # required for http server to monitor the updates of TLS files
tests/entrypoints/test_ssl_cert_refresher.py
0 → 100644
View file @
8db1b9d0
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
tempfile
from
pathlib
import
Path
from
ssl
import
SSLContext
import
pytest
from
vllm.entrypoints.ssl
import
SSLCertRefresher
class
MockSSLContext
(
SSLContext
):
def
__init__
(
self
):
self
.
load_cert_chain_count
=
0
self
.
load_ca_count
=
0
def
load_cert_chain
(
self
,
certfile
,
keyfile
=
None
,
password
=
None
,
):
self
.
load_cert_chain_count
+=
1
def
load_verify_locations
(
self
,
cafile
=
None
,
capath
=
None
,
cadata
=
None
,
):
self
.
load_ca_count
+=
1
def
create_file
()
->
str
:
with
tempfile
.
NamedTemporaryFile
(
dir
=
'/tmp'
,
delete
=
False
)
as
f
:
return
f
.
name
def
touch_file
(
path
:
str
)
->
None
:
Path
(
path
).
touch
()
@
pytest
.
mark
.
asyncio
async
def
test_ssl_refresher
():
ssl_context
=
MockSSLContext
()
key_path
=
create_file
()
cert_path
=
create_file
()
ca_path
=
create_file
()
ssl_refresher
=
SSLCertRefresher
(
ssl_context
,
key_path
,
cert_path
,
ca_path
)
await
asyncio
.
sleep
(
1
)
assert
ssl_context
.
load_cert_chain_count
==
0
assert
ssl_context
.
load_ca_count
==
0
touch_file
(
key_path
)
await
asyncio
.
sleep
(
1
)
assert
ssl_context
.
load_cert_chain_count
==
1
assert
ssl_context
.
load_ca_count
==
0
touch_file
(
cert_path
)
touch_file
(
ca_path
)
await
asyncio
.
sleep
(
1
)
assert
ssl_context
.
load_cert_chain_count
==
2
assert
ssl_context
.
load_ca_count
==
1
ssl_refresher
.
stop
()
touch_file
(
cert_path
)
touch_file
(
ca_path
)
await
asyncio
.
sleep
(
1
)
assert
ssl_context
.
load_cert_chain_count
==
2
assert
ssl_context
.
load_ca_count
==
1
vllm/entrypoints/api_server.py
View file @
8db1b9d0
...
@@ -128,6 +128,7 @@ async def run_server(args: Namespace,
...
@@ -128,6 +128,7 @@ async def run_server(args: Namespace,
shutdown_task
=
await
serve_http
(
shutdown_task
=
await
serve_http
(
app
,
app
,
sock
=
None
,
sock
=
None
,
enable_ssl_refresh
=
args
.
enable_ssl_refresh
,
host
=
args
.
host
,
host
=
args
.
host
,
port
=
args
.
port
,
port
=
args
.
port
,
log_level
=
args
.
log_level
,
log_level
=
args
.
log_level
,
...
@@ -152,6 +153,11 @@ if __name__ == "__main__":
...
@@ -152,6 +153,11 @@ if __name__ == "__main__":
type
=
str
,
type
=
str
,
default
=
None
,
default
=
None
,
help
=
"The CA certificates file"
)
help
=
"The CA certificates file"
)
parser
.
add_argument
(
"--enable-ssl-refresh"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Refresh SSL Context when SSL certificate files change"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--ssl-cert-reqs"
,
"--ssl-cert-reqs"
,
type
=
int
,
type
=
int
,
...
...
vllm/entrypoints/launcher.py
View file @
8db1b9d0
...
@@ -12,13 +12,16 @@ from fastapi import FastAPI, Request, Response
...
@@ -12,13 +12,16 @@ from fastapi import FastAPI, Request, Response
from
vllm
import
envs
from
vllm
import
envs
from
vllm.engine.async_llm_engine
import
AsyncEngineDeadError
from
vllm.engine.async_llm_engine
import
AsyncEngineDeadError
from
vllm.engine.multiprocessing
import
MQEngineDeadError
from
vllm.engine.multiprocessing
import
MQEngineDeadError
from
vllm.entrypoints.ssl
import
SSLCertRefresher
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
find_process_using_port
from
vllm.utils
import
find_process_using_port
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
async
def
serve_http
(
app
:
FastAPI
,
sock
:
Optional
[
socket
.
socket
],
async
def
serve_http
(
app
:
FastAPI
,
sock
:
Optional
[
socket
.
socket
],
enable_ssl_refresh
:
bool
=
False
,
**
uvicorn_kwargs
:
Any
):
**
uvicorn_kwargs
:
Any
):
logger
.
info
(
"Available routes are:"
)
logger
.
info
(
"Available routes are:"
)
for
route
in
app
.
routes
:
for
route
in
app
.
routes
:
...
@@ -31,6 +34,7 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket],
...
@@ -31,6 +34,7 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket],
logger
.
info
(
"Route: %s, Methods: %s"
,
path
,
', '
.
join
(
methods
))
logger
.
info
(
"Route: %s, Methods: %s"
,
path
,
', '
.
join
(
methods
))
config
=
uvicorn
.
Config
(
app
,
**
uvicorn_kwargs
)
config
=
uvicorn
.
Config
(
app
,
**
uvicorn_kwargs
)
config
.
load
()
server
=
uvicorn
.
Server
(
config
)
server
=
uvicorn
.
Server
(
config
)
_add_shutdown_handlers
(
app
,
server
)
_add_shutdown_handlers
(
app
,
server
)
...
@@ -39,9 +43,17 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket],
...
@@ -39,9 +43,17 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket],
server_task
=
loop
.
create_task
(
server_task
=
loop
.
create_task
(
server
.
serve
(
sockets
=
[
sock
]
if
sock
else
None
))
server
.
serve
(
sockets
=
[
sock
]
if
sock
else
None
))
ssl_cert_refresher
=
None
if
not
enable_ssl_refresh
else
SSLCertRefresher
(
ssl_context
=
config
.
ssl
,
key_path
=
config
.
ssl_keyfile
,
cert_path
=
config
.
ssl_certfile
,
ca_path
=
config
.
ssl_ca_certs
)
def
signal_handler
()
->
None
:
def
signal_handler
()
->
None
:
# prevents the uvicorn signal handler to exit early
# prevents the uvicorn signal handler to exit early
server_task
.
cancel
()
server_task
.
cancel
()
if
ssl_cert_refresher
:
ssl_cert_refresher
.
stop
()
async
def
dummy_shutdown
()
->
None
:
async
def
dummy_shutdown
()
->
None
:
pass
pass
...
...
vllm/entrypoints/openai/api_server.py
View file @
8db1b9d0
...
@@ -960,6 +960,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
...
@@ -960,6 +960,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
shutdown_task
=
await
serve_http
(
shutdown_task
=
await
serve_http
(
app
,
app
,
sock
=
sock
,
sock
=
sock
,
enable_ssl_refresh
=
args
.
enable_ssl_refresh
,
host
=
args
.
host
,
host
=
args
.
host
,
port
=
args
.
port
,
port
=
args
.
port
,
log_level
=
args
.
uvicorn_log_level
,
log_level
=
args
.
uvicorn_log_level
,
...
...
vllm/entrypoints/openai/cli_args.py
View file @
8db1b9d0
...
@@ -164,6 +164,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
...
@@ -164,6 +164,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
type
=
nullable_str
,
type
=
nullable_str
,
default
=
None
,
default
=
None
,
help
=
"The CA certificates file."
)
help
=
"The CA certificates file."
)
parser
.
add_argument
(
"--enable-ssl-refresh"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Refresh SSL Context when SSL certificate files change"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--ssl-cert-reqs"
,
"--ssl-cert-reqs"
,
type
=
int
,
type
=
int
,
...
...
vllm/entrypoints/ssl.py
0 → 100644
View file @
8db1b9d0
# SPDX-License-Identifier: Apache-2.0
import
asyncio
from
ssl
import
SSLContext
from
typing
import
Callable
,
Optional
from
watchfiles
import
Change
,
awatch
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
class
SSLCertRefresher
:
"""A class that monitors SSL certificate files and
reloads them when they change.
"""
def
__init__
(
self
,
ssl_context
:
SSLContext
,
key_path
:
Optional
[
str
]
=
None
,
cert_path
:
Optional
[
str
]
=
None
,
ca_path
:
Optional
[
str
]
=
None
)
->
None
:
self
.
ssl
=
ssl_context
self
.
key_path
=
key_path
self
.
cert_path
=
cert_path
self
.
ca_path
=
ca_path
# Setup certification chain watcher
def
update_ssl_cert_chain
(
change
:
Change
,
file_path
:
str
)
->
None
:
logger
.
info
(
"Reloading SSL certificate chain"
)
assert
self
.
key_path
and
self
.
cert_path
self
.
ssl
.
load_cert_chain
(
self
.
cert_path
,
self
.
key_path
)
self
.
watch_ssl_cert_task
=
None
if
self
.
key_path
and
self
.
cert_path
:
self
.
watch_ssl_cert_task
=
asyncio
.
create_task
(
self
.
_watch_files
([
self
.
key_path
,
self
.
cert_path
],
update_ssl_cert_chain
))
# Setup CA files watcher
def
update_ssl_ca
(
change
:
Change
,
file_path
:
str
)
->
None
:
logger
.
info
(
"Reloading SSL CA certificates"
)
assert
self
.
ca_path
self
.
ssl
.
load_verify_locations
(
self
.
ca_path
)
self
.
watch_ssl_ca_task
=
None
if
self
.
ca_path
:
self
.
watch_ssl_ca_task
=
asyncio
.
create_task
(
self
.
_watch_files
([
self
.
ca_path
],
update_ssl_ca
))
async
def
_watch_files
(
self
,
paths
,
fun
:
Callable
[[
Change
,
str
],
None
])
->
None
:
"""Watch multiple file paths asynchronously."""
logger
.
info
(
"SSLCertRefresher monitors files: %s"
,
paths
)
async
for
changes
in
awatch
(
*
paths
):
try
:
for
change
,
file_path
in
changes
:
logger
.
info
(
"File change detected: %s - %s"
,
change
.
name
,
file_path
)
fun
(
change
,
file_path
)
except
Exception
as
e
:
logger
.
error
(
"SSLCertRefresher failed taking action on file change. "
"Error: %s"
,
e
)
def
stop
(
self
)
->
None
:
"""Stop watching files."""
if
self
.
watch_ssl_cert_task
:
self
.
watch_ssl_cert_task
.
cancel
()
self
.
watch_ssl_cert_task
=
None
if
self
.
watch_ssl_ca_task
:
self
.
watch_ssl_ca_task
.
cancel
()
self
.
watch_ssl_ca_task
=
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment