Unverified Commit a4bbe492 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat(http): TLS support (#2492)

parent 04442173
...@@ -367,6 +367,29 @@ version = "1.4.0" ...@@ -367,6 +367,29 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
[[package]]
name = "aws-lc-rs"
version = "1.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c953fe1ba023e6b7730c0d4b031d06f267f23a46167dcbd40316644b10a17ba"
dependencies = [
"aws-lc-sys",
"zeroize",
]
[[package]]
name = "aws-lc-sys"
version = "0.30.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbfd150b5dbdb988bcc8fb1fe787eb6b7ee6180ca24da683b61ea5405f3d43ff"
dependencies = [
"bindgen 0.69.5",
"cc",
"cmake",
"dunce",
"fs_extra",
]
[[package]] [[package]]
name = "axum" name = "axum"
version = "0.7.9" version = "0.7.9"
...@@ -480,6 +503,28 @@ dependencies = [ ...@@ -480,6 +503,28 @@ dependencies = [
"syn 2.0.100", "syn 2.0.100",
] ]
[[package]]
name = "axum-server"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "495c05f60d6df0093e8fb6e74aa5846a0ad06abaf96d76166283720bf740f8ab"
dependencies = [
"arc-swap",
"bytes",
"fs-err",
"http 1.3.1",
"http-body 1.0.1",
"hyper 1.6.0",
"hyper-util",
"pin-project-lite",
"rustls",
"rustls-pemfile",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower-service",
]
[[package]] [[package]]
name = "backoff" name = "backoff"
version = "0.4.0" version = "0.4.0"
...@@ -1787,6 +1832,12 @@ dependencies = [ ...@@ -1787,6 +1832,12 @@ dependencies = [
"dtoa", "dtoa",
] ]
[[package]]
name = "dunce"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
[[package]] [[package]]
name = "dyn-clone" name = "dyn-clone"
version = "1.0.19" version = "1.0.19"
...@@ -1858,6 +1909,7 @@ dependencies = [ ...@@ -1858,6 +1909,7 @@ dependencies = [
"async-trait", "async-trait",
"async_zmq", "async_zmq",
"axum 0.8.3", "axum 0.8.3",
"axum-server",
"blake3", "blake3",
"bs62", "bs62",
"bytemuck", "bytemuck",
...@@ -1899,6 +1951,7 @@ dependencies = [ ...@@ -1899,6 +1951,7 @@ dependencies = [
"rmp-serde", "rmp-serde",
"rstest 0.18.2", "rstest 0.18.2",
"rstest_reuse", "rstest_reuse",
"rustls",
"serde", "serde",
"serde_json", "serde_json",
"serial_test", "serial_test",
...@@ -2418,6 +2471,22 @@ dependencies = [ ...@@ -2418,6 +2471,22 @@ dependencies = [
"num", "num",
] ]
[[package]]
name = "fs-err"
version = "3.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88d7be93788013f265201256d58f04936a8079ad5dc898743aa20525f503b683"
dependencies = [
"autocfg",
"tokio",
]
[[package]]
name = "fs_extra"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]] [[package]]
name = "fuchsia-zircon" name = "fuchsia-zircon"
version = "0.3.3" version = "0.3.3"
...@@ -6086,6 +6155,7 @@ version = "0.23.26" ...@@ -6086,6 +6155,7 @@ version = "0.23.26"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df51b5869f3a441595eac5e8ff14d486ff285f7b8c0df8770e49c3b56351f0f0" checksum = "df51b5869f3a441595eac5e8ff14d486ff285f7b8c0df8770e49c3b56351f0f0"
dependencies = [ dependencies = [
"aws-lc-rs",
"log", "log",
"once_cell", "once_cell",
"ring", "ring",
...@@ -6154,6 +6224,7 @@ version = "0.103.1" ...@@ -6154,6 +6224,7 @@ version = "0.103.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fef8b8769aaccf73098557a87cd1816b4f9c7c16811c9c77142aa695c16f2c03" checksum = "fef8b8769aaccf73098557a87cd1816b4f9c7c16811c9c77142aa695c16f2c03"
dependencies = [ dependencies = [
"aws-lc-rs",
"ring", "ring",
"rustls-pki-types", "rustls-pki-types",
"untrusted", "untrusted",
......
...@@ -118,8 +118,9 @@ Dynamo provides a simple way to spin up a local set of inference components incl ...@@ -118,8 +118,9 @@ Dynamo provides a simple way to spin up a local set of inference components incl
- **Workers** – Set of pre-configured LLM serving engines. - **Workers** – Set of pre-configured LLM serving engines.
``` ```
# Start an OpenAI compatible HTTP server, a pre-processor (prompt templating and tokenization) and a router: # Start an OpenAI compatible HTTP server, a pre-processor (prompt templating and tokenization) and a router.
python -m dynamo.frontend --http-port 8080 # Pass the TLS certificate and key paths to use HTTPS instead of HTTP.
python -m dynamo.frontend --http-port 8080 [--tls-cert-path cert.pem] [--tls-key-path key.pem]
# Start the SGLang engine, connecting to NATS and etcd to receive requests. You can run several of these, # Start the SGLang engine, connecting to NATS and etcd to receive requests. You can run several of these,
# both for the same model and for multiple models. The frontend node will discover them. # both for the same model and for multiple models. The frontend node will discover them.
......
...@@ -16,10 +16,15 @@ ...@@ -16,10 +16,15 @@
# Worker example: # Worker example:
# - cd lib/bindings/python/examples/hello_world # - cd lib/bindings/python/examples/hello_world
# - python server_sglang_static.py # - python server_sglang_static.py
#
# For TLS:
# - python -m dynamo.frontend --http-port 8443 --tls-cert-path cert.pem --tls-key-path key.pem
#
import argparse import argparse
import asyncio import asyncio
import os import os
import pathlib
import re import re
import uvloop import uvloop
...@@ -85,6 +90,18 @@ def parse_args(): ...@@ -85,6 +90,18 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--http-port", type=int, default=8080, help="HTTP port for the engine (u16)." "--http-port", type=int, default=8080, help="HTTP port for the engine (u16)."
) )
parser.add_argument(
"--tls-cert-path",
type=pathlib.Path,
default=None,
help="TLS certificate path, PEM format.",
)
parser.add_argument(
"--tls-key-path",
type=pathlib.Path,
default=None,
help="TLS certificate key path, PEM format.",
)
parser.add_argument( parser.add_argument(
"--router-mode", "--router-mode",
type=str, type=str,
...@@ -149,6 +166,8 @@ def parse_args(): ...@@ -149,6 +166,8 @@ def parse_args():
if flags.static_endpoint and (not flags.model_name or not flags.model_path): if flags.static_endpoint and (not flags.model_name or not flags.model_path):
parser.error("--static-endpoint requires both --model-name and --model-path") parser.error("--static-endpoint requires both --model-name and --model-path")
if bool(flags.tls_cert_path) ^ bool(flags.tls_key_path): # ^ is XOR
parser.error("--tls-cert-path and --tls-key-path must be provided together")
return flags return flags
...@@ -192,6 +211,10 @@ async def async_main(): ...@@ -192,6 +211,10 @@ async def async_main():
kwargs["model_name"] = flags.model_name kwargs["model_name"] = flags.model_name
if flags.model_path: if flags.model_path:
kwargs["model_path"] = flags.model_path kwargs["model_path"] = flags.model_path
if flags.tls_cert_path:
kwargs["tls_cert_path"] = flags.tls_cert_path
if flags.tls_key_path:
kwargs["tls_key_path"] = flags.tls_key_path
if is_static: if is_static:
# out=dyn://<static_endpoint> # out=dyn://<static_endpoint>
......
...@@ -45,9 +45,18 @@ pub struct Flags { ...@@ -45,9 +45,18 @@ pub struct Flags {
pub model_path_flag: Option<PathBuf>, pub model_path_flag: Option<PathBuf>,
/// HTTP port. `in=http` only /// HTTP port. `in=http` only
/// If tls_cert_path and tls_key_path are provided, this will be TLS/HTTPS.
#[arg(long, default_value = "8080")] #[arg(long, default_value = "8080")]
pub http_port: u16, pub http_port: u16,
/// TLS certificate file
#[arg(long, requires = "tls_key_path")]
pub tls_cert_path: Option<PathBuf>,
/// TLS certificate key file
#[arg(long, requires = "tls_cert_path")]
pub tls_key_path: Option<PathBuf>,
/// The name of the model we are serving /// The name of the model we are serving
#[arg(long)] #[arg(long)]
pub model_name: Option<String>, pub model_name: Option<String>,
......
...@@ -20,7 +20,7 @@ pub async fn run( ...@@ -20,7 +20,7 @@ pub async fn run(
runtime: Runtime, runtime: Runtime,
in_opt: Input, in_opt: Input,
out_opt: Option<Output>, out_opt: Option<Output>,
flags: Flags, mut flags: Flags,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// //
// Configure // Configure
...@@ -39,7 +39,9 @@ pub async fn run( ...@@ -39,7 +39,9 @@ pub async fn run(
.kv_cache_block_size(flags.kv_cache_block_size) .kv_cache_block_size(flags.kv_cache_block_size)
// Only set if user provides. Usually loaded from tokenizer_config.json // Only set if user provides. Usually loaded from tokenizer_config.json
.context_length(flags.context_length) .context_length(flags.context_length)
.http_port(Some(flags.http_port)) .http_port(flags.http_port)
.tls_cert_path(flags.tls_cert_path.take())
.tls_key_path(flags.tls_key_path.take())
.router_config(Some(flags.router_config())) .router_config(Some(flags.router_config()))
.request_template(flags.request_template.clone()) .request_template(flags.request_template.clone())
.migration_limit(flags.migration_limit) .migration_limit(flags.migration_limit)
......
...@@ -316,6 +316,29 @@ version = "1.4.0" ...@@ -316,6 +316,29 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
[[package]]
name = "aws-lc-rs"
version = "1.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c953fe1ba023e6b7730c0d4b031d06f267f23a46167dcbd40316644b10a17ba"
dependencies = [
"aws-lc-sys",
"zeroize",
]
[[package]]
name = "aws-lc-sys"
version = "0.30.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbfd150b5dbdb988bcc8fb1fe787eb6b7ee6180ca24da683b61ea5405f3d43ff"
dependencies = [
"bindgen 0.69.5",
"cc",
"cmake",
"dunce",
"fs_extra",
]
[[package]] [[package]]
name = "axum" name = "axum"
version = "0.8.3" version = "0.8.3"
...@@ -382,6 +405,28 @@ dependencies = [ ...@@ -382,6 +405,28 @@ dependencies = [
"syn 2.0.100", "syn 2.0.100",
] ]
[[package]]
name = "axum-server"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "495c05f60d6df0093e8fb6e74aa5846a0ad06abaf96d76166283720bf740f8ab"
dependencies = [
"arc-swap",
"bytes",
"fs-err",
"http",
"http-body",
"hyper",
"hyper-util",
"pin-project-lite",
"rustls",
"rustls-pemfile",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower-service",
]
[[package]] [[package]]
name = "backoff" name = "backoff"
version = "0.4.0" version = "0.4.0"
...@@ -429,6 +474,29 @@ version = "1.7.3" ...@@ -429,6 +474,29 @@ version = "1.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89e25b6adfb930f02d1981565a6e5d9c547ac15a96606256d3b59040e5cd4ca3" checksum = "89e25b6adfb930f02d1981565a6e5d9c547ac15a96606256d3b59040e5cd4ca3"
[[package]]
name = "bindgen"
version = "0.69.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
dependencies = [
"bitflags 2.9.0",
"cexpr",
"clang-sys",
"itertools 0.12.1",
"lazy_static",
"lazycell",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.100",
"which",
]
[[package]] [[package]]
name = "bindgen" name = "bindgen"
version = "0.71.1" version = "0.71.1"
...@@ -444,7 +512,7 @@ dependencies = [ ...@@ -444,7 +512,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"regex", "regex",
"rustc-hash", "rustc-hash 2.1.1",
"shlex", "shlex",
"syn 2.0.100", "syn 2.0.100",
] ]
...@@ -680,6 +748,15 @@ version = "0.7.4" ...@@ -680,6 +748,15 @@ version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6"
[[package]]
name = "cmake"
version = "0.1.54"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
dependencies = [
"cc",
]
[[package]] [[package]]
name = "colorchoice" name = "colorchoice"
version = "1.0.3" version = "1.0.3"
...@@ -1113,6 +1190,12 @@ dependencies = [ ...@@ -1113,6 +1190,12 @@ dependencies = [
"pyo3", "pyo3",
] ]
[[package]]
name = "dunce"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
[[package]] [[package]]
name = "dyn-stack" name = "dyn-stack"
version = "0.10.0" version = "0.10.0"
...@@ -1145,6 +1228,7 @@ dependencies = [ ...@@ -1145,6 +1228,7 @@ dependencies = [
"async-trait", "async-trait",
"async_zmq", "async_zmq",
"axum", "axum",
"axum-server",
"blake3", "blake3",
"bs62", "bs62",
"bytemuck", "bytemuck",
...@@ -1179,6 +1263,7 @@ dependencies = [ ...@@ -1179,6 +1263,7 @@ dependencies = [
"rayon", "rayon",
"regex", "regex",
"rmp-serde", "rmp-serde",
"rustls",
"serde", "serde",
"serde_json", "serde_json",
"strum", "strum",
...@@ -1519,6 +1604,22 @@ dependencies = [ ...@@ -1519,6 +1604,22 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "fs-err"
version = "3.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88d7be93788013f265201256d58f04936a8079ad5dc898743aa20525f503b683"
dependencies = [
"autocfg",
"tokio",
]
[[package]]
name = "fs_extra"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]] [[package]]
name = "fuchsia-zircon" name = "fuchsia-zircon"
version = "0.3.3" version = "0.3.3"
...@@ -2030,6 +2131,15 @@ dependencies = [ ...@@ -2030,6 +2131,15 @@ dependencies = [
"windows-sys 0.59.0", "windows-sys 0.59.0",
] ]
[[package]]
name = "home"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
dependencies = [
"windows-sys 0.59.0",
]
[[package]] [[package]]
name = "http" name = "http"
version = "1.3.1" version = "1.3.1"
...@@ -2425,6 +2535,15 @@ version = "1.70.1" ...@@ -2425,6 +2535,15 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "itertools"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
dependencies = [
"either",
]
[[package]] [[package]]
name = "itertools" name = "itertools"
version = "0.13.0" version = "0.13.0"
...@@ -2495,6 +2614,12 @@ version = "1.5.0" ...@@ -2495,6 +2614,12 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.172" version = "0.2.172"
...@@ -2527,6 +2652,12 @@ dependencies = [ ...@@ -2527,6 +2652,12 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "linux-raw-sys"
version = "0.4.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
[[package]] [[package]]
name = "linux-raw-sys" name = "linux-raw-sys"
version = "0.9.4" version = "0.9.4"
...@@ -2862,7 +2993,7 @@ version = "0.4.1" ...@@ -2862,7 +2993,7 @@ version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "743ed1038b386b75451f9e0bba37cb2e3eea75873635268337d6531be99c9303" checksum = "743ed1038b386b75451f9e0bba37cb2e3eea75873635268337d6531be99c9303"
dependencies = [ dependencies = [
"bindgen", "bindgen 0.71.1",
"cc", "cc",
"libc", "libc",
"os_info", "os_info",
...@@ -3580,7 +3711,7 @@ dependencies = [ ...@@ -3580,7 +3711,7 @@ dependencies = [
"pin-project-lite", "pin-project-lite",
"quinn-proto", "quinn-proto",
"quinn-udp", "quinn-udp",
"rustc-hash", "rustc-hash 2.1.1",
"rustls", "rustls",
"socket2", "socket2",
"thiserror 2.0.12", "thiserror 2.0.12",
...@@ -3599,7 +3730,7 @@ dependencies = [ ...@@ -3599,7 +3730,7 @@ dependencies = [
"getrandom 0.3.2", "getrandom 0.3.2",
"rand 0.9.1", "rand 0.9.1",
"ring", "ring",
"rustc-hash", "rustc-hash 2.1.1",
"rustls", "rustls",
"rustls-pki-types", "rustls-pki-types",
"slab", "slab",
...@@ -3936,6 +4067,12 @@ version = "0.1.24" ...@@ -3936,6 +4067,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]] [[package]]
name = "rustc-hash" name = "rustc-hash"
version = "2.1.1" version = "2.1.1"
...@@ -3951,6 +4088,19 @@ dependencies = [ ...@@ -3951,6 +4088,19 @@ dependencies = [
"semver", "semver",
] ]
[[package]]
name = "rustix"
version = "0.38.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
dependencies = [
"bitflags 2.9.0",
"errno",
"libc",
"linux-raw-sys 0.4.15",
"windows-sys 0.59.0",
]
[[package]] [[package]]
name = "rustix" name = "rustix"
version = "1.0.5" version = "1.0.5"
...@@ -3960,7 +4110,7 @@ dependencies = [ ...@@ -3960,7 +4110,7 @@ dependencies = [
"bitflags 2.9.0", "bitflags 2.9.0",
"errno", "errno",
"libc", "libc",
"linux-raw-sys", "linux-raw-sys 0.9.4",
"windows-sys 0.59.0", "windows-sys 0.59.0",
] ]
...@@ -3970,6 +4120,7 @@ version = "0.23.26" ...@@ -3970,6 +4120,7 @@ version = "0.23.26"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df51b5869f3a441595eac5e8ff14d486ff285f7b8c0df8770e49c3b56351f0f0" checksum = "df51b5869f3a441595eac5e8ff14d486ff285f7b8c0df8770e49c3b56351f0f0"
dependencies = [ dependencies = [
"aws-lc-rs",
"log", "log",
"once_cell", "once_cell",
"ring", "ring",
...@@ -4038,6 +4189,7 @@ version = "0.103.1" ...@@ -4038,6 +4189,7 @@ version = "0.103.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fef8b8769aaccf73098557a87cd1816b4f9c7c16811c9c77142aa695c16f2c03" checksum = "fef8b8769aaccf73098557a87cd1816b4f9c7c16811c9c77142aa695c16f2c03"
dependencies = [ dependencies = [
"aws-lc-rs",
"ring", "ring",
"rustls-pki-types", "rustls-pki-types",
"untrusted", "untrusted",
...@@ -4522,7 +4674,7 @@ dependencies = [ ...@@ -4522,7 +4674,7 @@ dependencies = [
"fastrand", "fastrand",
"getrandom 0.3.2", "getrandom 0.3.2",
"once_cell", "once_cell",
"rustix", "rustix 1.0.5",
"windows-sys 0.59.0", "windows-sys 0.59.0",
] ]
...@@ -5375,6 +5527,18 @@ dependencies = [ ...@@ -5375,6 +5527,18 @@ dependencies = [
"rustls-pki-types", "rustls-pki-types",
] ]
[[package]]
name = "which"
version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
dependencies = [
"either",
"home",
"once_cell",
"rustix 0.38.44",
]
[[package]] [[package]]
name = "winapi" name = "winapi"
version = "0.2.8" version = "0.2.8"
......
...@@ -10,6 +10,7 @@ use dynamo_llm::entrypoint::input::Input; ...@@ -10,6 +10,7 @@ use dynamo_llm::entrypoint::input::Input;
use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig; use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig;
use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig; use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig;
use dynamo_llm::kv_router::KvRouterConfig as RsKvRouterConfig; use dynamo_llm::kv_router::KvRouterConfig as RsKvRouterConfig;
use dynamo_llm::local_model::DEFAULT_HTTP_PORT;
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder}; use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
use dynamo_llm::mocker::protocols::MockEngineArgs; use dynamo_llm::mocker::protocols::MockEngineArgs;
use dynamo_runtime::protocols::Endpoint as EndpointId; use dynamo_runtime::protocols::Endpoint as EndpointId;
...@@ -94,7 +95,9 @@ pub(crate) struct EntrypointArgs { ...@@ -94,7 +95,9 @@ pub(crate) struct EntrypointArgs {
template_file: Option<PathBuf>, template_file: Option<PathBuf>,
router_config: Option<RouterConfig>, router_config: Option<RouterConfig>,
kv_cache_block_size: Option<u32>, kv_cache_block_size: Option<u32>,
http_port: Option<u16>, http_port: u16,
tls_cert_path: Option<PathBuf>,
tls_key_path: Option<PathBuf>,
extra_engine_args: Option<PathBuf>, extra_engine_args: Option<PathBuf>,
} }
...@@ -102,7 +105,7 @@ pub(crate) struct EntrypointArgs { ...@@ -102,7 +105,7 @@ pub(crate) struct EntrypointArgs {
impl EntrypointArgs { impl EntrypointArgs {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#[new] #[new]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, model_config=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_port=None, extra_engine_args=None))] #[pyo3(signature = (engine_type, model_path=None, model_name=None, model_config=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None))]
pub fn new( pub fn new(
engine_type: EngineType, engine_type: EngineType,
model_path: Option<PathBuf>, model_path: Option<PathBuf>,
...@@ -114,6 +117,8 @@ impl EntrypointArgs { ...@@ -114,6 +117,8 @@ impl EntrypointArgs {
router_config: Option<RouterConfig>, router_config: Option<RouterConfig>,
kv_cache_block_size: Option<u32>, kv_cache_block_size: Option<u32>,
http_port: Option<u16>, http_port: Option<u16>,
tls_cert_path: Option<PathBuf>,
tls_key_path: Option<PathBuf>,
extra_engine_args: Option<PathBuf>, extra_engine_args: Option<PathBuf>,
) -> PyResult<Self> { ) -> PyResult<Self> {
let endpoint_id_obj: Option<EndpointId> = match endpoint_id { let endpoint_id_obj: Option<EndpointId> = match endpoint_id {
...@@ -124,6 +129,13 @@ impl EntrypointArgs { ...@@ -124,6 +129,13 @@ impl EntrypointArgs {
})?), })?),
None => None, None => None,
}; };
if (tls_cert_path.is_some() && tls_key_path.is_none())
|| (tls_cert_path.is_none() && tls_key_path.is_some())
{
return Err(pyo3::exceptions::PyValueError::new_err(
"tls_cert_path and tls_key_path must be provided together",
));
}
Ok(EntrypointArgs { Ok(EntrypointArgs {
engine_type, engine_type,
model_path, model_path,
...@@ -134,7 +146,9 @@ impl EntrypointArgs { ...@@ -134,7 +146,9 @@ impl EntrypointArgs {
template_file, template_file,
router_config, router_config,
kv_cache_block_size, kv_cache_block_size,
http_port, http_port: http_port.unwrap_or(DEFAULT_HTTP_PORT),
tls_cert_path,
tls_key_path,
extra_engine_args, extra_engine_args,
}) })
} }
...@@ -164,6 +178,8 @@ pub fn make_engine<'p>( ...@@ -164,6 +178,8 @@ pub fn make_engine<'p>(
.kv_cache_block_size(args.kv_cache_block_size) .kv_cache_block_size(args.kv_cache_block_size)
.router_config(args.router_config.clone().map(|rc| rc.into())) .router_config(args.router_config.clone().map(|rc| rc.into()))
.http_port(args.http_port) .http_port(args.http_port)
.tls_cert_path(args.tls_cert_path.clone())
.tls_key_path(args.tls_key_path.clone())
.is_mocker(matches!(args.engine_type, EngineType::Mocker)) .is_mocker(matches!(args.engine_type, EngineType::Mocker))
.extra_engine_args(args.extra_engine_args.clone()); .extra_engine_args(args.extra_engine_args.clone());
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
......
...@@ -101,7 +101,9 @@ unicode-segmentation = "1.12" ...@@ -101,7 +101,9 @@ unicode-segmentation = "1.12"
# http-service # http-service
axum = { workspace = true } axum = { workspace = true }
axum-server = { version = "0.7", features = ["tls-rustls"] }
tower-http = {workspace = true} tower-http = {workspace = true}
rustls = { version = "0.23" }
# tokenizers # tokenizers
......
...@@ -22,9 +22,31 @@ use dynamo_runtime::{DistributedRuntime, Runtime}; ...@@ -22,9 +22,31 @@ use dynamo_runtime::{DistributedRuntime, Runtime};
/// Build and run an HTTP service /// Build and run an HTTP service
pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> { pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> {
let mut http_service_builder = service_v2::HttpService::builder() let local_model = engine_config.local_model();
.port(engine_config.local_model().http_port()) let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) {
.with_request_template(engine_config.local_model().request_template()); (Some(tls_cert_path), Some(tls_key_path)) => {
if !tls_cert_path.exists() {
anyhow::bail!("TLS certificate not found: {}", tls_cert_path.display());
}
if !tls_key_path.exists() {
anyhow::bail!("TLS key not found: {}", tls_key_path.display());
}
service_v2::HttpService::builder()
.enable_tls(true)
.tls_cert_path(Some(tls_cert_path.to_path_buf()))
.tls_key_path(Some(tls_key_path.to_path_buf()))
.port(local_model.http_port())
}
(None, None) => service_v2::HttpService::builder().port(local_model.http_port()),
(_, _) => {
// CLI should prevent us ever getting here
anyhow::bail!(
"Both --tls-cert-path and --tls-key-path must be provided together to enable TLS"
);
}
};
http_service_builder =
http_service_builder.with_request_template(engine_config.local_model().request_template());
let http_service = match engine_config { let http_service = match engine_config {
EngineConfig::Dynamic(_) => { EngineConfig::Dynamic(_) => {
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::env::var; use std::env::var;
use std::path::PathBuf;
use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
use std::sync::Arc; use std::sync::Arc;
...@@ -15,9 +16,11 @@ use crate::discovery::ModelManager; ...@@ -15,9 +16,11 @@ use crate::discovery::ModelManager;
use crate::endpoint_type::EndpointType; use crate::endpoint_type::EndpointType;
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
use anyhow::Result; use anyhow::Result;
use axum_server::tls_rustls::RustlsConfig;
use derive_builder::Builder; use derive_builder::Builder;
use dynamo_runtime::logging::make_request_span; use dynamo_runtime::logging::make_request_span;
use dynamo_runtime::transports::etcd; use dynamo_runtime::transports::etcd;
use std::net::SocketAddr;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tower_http::trace::TraceLayer; use tower_http::trace::TraceLayer;
...@@ -126,6 +129,9 @@ pub struct HttpService { ...@@ -126,6 +129,9 @@ pub struct HttpService {
router: axum::Router, router: axum::Router,
port: u16, port: u16,
host: String, host: String,
enable_tls: bool,
tls_cert_path: Option<PathBuf>,
tls_key_path: Option<PathBuf>,
route_docs: Vec<RouteDoc>, route_docs: Vec<RouteDoc>,
} }
...@@ -138,6 +144,15 @@ pub struct HttpServiceConfig { ...@@ -138,6 +144,15 @@ pub struct HttpServiceConfig {
#[builder(setter(into), default = "String::from(\"0.0.0.0\")")] #[builder(setter(into), default = "String::from(\"0.0.0.0\")")]
host: String, host: String,
#[builder(default = "false")]
enable_tls: bool,
#[builder(default = "None")]
tls_cert_path: Option<PathBuf>,
#[builder(default = "None")]
tls_key_path: Option<PathBuf>,
// #[builder(default)] // #[builder(default)]
// custom: Vec<axum::Router> // custom: Vec<axum::Router>
#[builder(default = "false")] #[builder(default = "false")]
...@@ -183,19 +198,61 @@ impl HttpService { ...@@ -183,19 +198,61 @@ impl HttpService {
pub async fn run(&self, cancel_token: CancellationToken) -> Result<()> { pub async fn run(&self, cancel_token: CancellationToken) -> Result<()> {
let address = format!("{}:{}", self.host, self.port); let address = format!("{}:{}", self.host, self.port);
tracing::info!(address, "Starting HTTP service on: {address}"); let protocol = if self.enable_tls { "HTTPS" } else { "HTTP" };
tracing::info!(protocol, address, "Starting HTTP(S) service");
let listener = tokio::net::TcpListener::bind(address.as_str())
.await
.unwrap_or_else(|_| panic!("could not bind to address: {address}"));
let router = self.router.clone(); let router = self.router.clone();
let observer = cancel_token.child_token(); let observer = cancel_token.child_token();
axum::serve(listener, router) let addr: SocketAddr = address
.with_graceful_shutdown(observer.cancelled_owned()) .parse()
.await .map_err(|e| anyhow::anyhow!("Invalid address '{}': {}", address, e))?;
.inspect_err(|_| cancel_token.cancel())?;
if self.enable_tls {
let cert_path = self
.tls_cert_path
.as_ref()
.ok_or_else(|| anyhow::anyhow!("TLS certificate path not provided"))?;
let key_path = self
.tls_key_path
.as_ref()
.ok_or_else(|| anyhow::anyhow!("TLS private key path not provided"))?;
// aws_lc_rs is the default but other crates pull in `ring` also,
// so rustls doesn't know which one to use. Tell it.
if let Err(e) = rustls::crypto::aws_lc_rs::default_provider().install_default() {
tracing::debug!("TLS crypto provider already installed: {e:?}");
}
let config = RustlsConfig::from_pem_file(cert_path, key_path)
.await
.map_err(|e| anyhow::anyhow!("Failed to create TLS config: {}", e))?;
let handle = axum_server::Handle::new();
let server = axum_server::bind_rustls(addr, config)
.handle(handle.clone())
.serve(router.into_make_service());
tokio::select! {
result = server => {
result.map_err(|e| anyhow::anyhow!("HTTPS server error: {}", e))?;
}
_ = observer.cancelled() => {
tracing::info!("HTTPS server shutdown requested");
handle.graceful_shutdown(Some(Duration::from_secs(5)));
// TODO: Do we need to wait?
}
}
} else {
let listener = tokio::net::TcpListener::bind(addr)
.await
.unwrap_or_else(|_| panic!("could not bind to address: {address}"));
axum::serve(listener, router)
.with_graceful_shutdown(observer.cancelled_owned())
.await
.inspect_err(|_| cancel_token.cancel())?;
}
Ok(()) Ok(())
} }
...@@ -283,6 +340,9 @@ impl HttpServiceConfigBuilder { ...@@ -283,6 +340,9 @@ impl HttpServiceConfigBuilder {
router, router,
port: config.port, port: config.port,
host: config.host, host: config.host,
enable_tls: config.enable_tls,
tls_cert_path: config.tls_cert_path,
tls_key_path: config.tls_key_path,
route_docs: all_docs, route_docs: all_docs,
}) })
} }
......
...@@ -38,7 +38,8 @@ const DEFAULT_NAME: &str = "dynamo"; ...@@ -38,7 +38,8 @@ const DEFAULT_NAME: &str = "dynamo";
const DEFAULT_KV_CACHE_BLOCK_SIZE: u32 = 16; const DEFAULT_KV_CACHE_BLOCK_SIZE: u32 = 16;
/// We can't have it default to 0, so pick something /// We can't have it default to 0, so pick something
const DEFAULT_HTTP_PORT: u16 = 8080; /// 'pub' because the bindings use it for consistency.
pub const DEFAULT_HTTP_PORT: u16 = 8080;
pub struct LocalModelBuilder { pub struct LocalModelBuilder {
model_path: Option<PathBuf>, model_path: Option<PathBuf>,
...@@ -50,6 +51,8 @@ pub struct LocalModelBuilder { ...@@ -50,6 +51,8 @@ pub struct LocalModelBuilder {
router_config: Option<RouterConfig>, router_config: Option<RouterConfig>,
kv_cache_block_size: u32, kv_cache_block_size: u32,
http_port: u16, http_port: u16,
tls_cert_path: Option<PathBuf>,
tls_key_path: Option<PathBuf>,
migration_limit: u32, migration_limit: u32,
is_mocker: bool, is_mocker: bool,
extra_engine_args: Option<PathBuf>, extra_engine_args: Option<PathBuf>,
...@@ -62,6 +65,8 @@ impl Default for LocalModelBuilder { ...@@ -62,6 +65,8 @@ impl Default for LocalModelBuilder {
LocalModelBuilder { LocalModelBuilder {
kv_cache_block_size: DEFAULT_KV_CACHE_BLOCK_SIZE, kv_cache_block_size: DEFAULT_KV_CACHE_BLOCK_SIZE,
http_port: DEFAULT_HTTP_PORT, http_port: DEFAULT_HTTP_PORT,
tls_cert_path: Default::default(),
tls_key_path: Default::default(),
model_path: Default::default(), model_path: Default::default(),
model_name: Default::default(), model_name: Default::default(),
model_config: Default::default(), model_config: Default::default(),
...@@ -110,9 +115,18 @@ impl LocalModelBuilder { ...@@ -110,9 +115,18 @@ impl LocalModelBuilder {
self self
} }
/// Passing None resets it to default pub fn http_port(&mut self, port: u16) -> &mut Self {
pub fn http_port(&mut self, port: Option<u16>) -> &mut Self { self.http_port = port;
self.http_port = port.unwrap_or(DEFAULT_HTTP_PORT); self
}
pub fn tls_cert_path(&mut self, p: Option<PathBuf>) -> &mut Self {
self.tls_cert_path = p;
self
}
pub fn tls_key_path(&mut self, p: Option<PathBuf>) -> &mut Self {
self.tls_key_path = p;
self self
} }
...@@ -187,6 +201,8 @@ impl LocalModelBuilder { ...@@ -187,6 +201,8 @@ impl LocalModelBuilder {
endpoint_id, endpoint_id,
template, template,
http_port: self.http_port, http_port: self.http_port,
tls_cert_path: self.tls_cert_path.take(),
tls_key_path: self.tls_key_path.take(),
router_config: self.router_config.take().unwrap_or_default(), router_config: self.router_config.take().unwrap_or_default(),
runtime_config: self.runtime_config.clone(), runtime_config: self.runtime_config.clone(),
}); });
...@@ -260,6 +276,8 @@ impl LocalModelBuilder { ...@@ -260,6 +276,8 @@ impl LocalModelBuilder {
endpoint_id, endpoint_id,
template, template,
http_port: self.http_port, http_port: self.http_port,
tls_cert_path: self.tls_cert_path.take(),
tls_key_path: self.tls_key_path.take(),
router_config: self.router_config.take().unwrap_or_default(), router_config: self.router_config.take().unwrap_or_default(),
runtime_config: self.runtime_config.clone(), runtime_config: self.runtime_config.clone(),
}) })
...@@ -272,7 +290,9 @@ pub struct LocalModel { ...@@ -272,7 +290,9 @@ pub struct LocalModel {
card: ModelDeploymentCard, card: ModelDeploymentCard,
endpoint_id: EndpointId, endpoint_id: EndpointId,
template: Option<RequestTemplate>, template: Option<RequestTemplate>,
http_port: u16, // Only used if input is HTTP server http_port: u16,
tls_cert_path: Option<PathBuf>,
tls_key_path: Option<PathBuf>,
router_config: RouterConfig, router_config: RouterConfig,
runtime_config: ModelRuntimeConfig, runtime_config: ModelRuntimeConfig,
} }
...@@ -305,6 +325,14 @@ impl LocalModel { ...@@ -305,6 +325,14 @@ impl LocalModel {
self.http_port self.http_port
} }
pub fn tls_cert_path(&self) -> Option<&Path> {
self.tls_cert_path.as_deref()
}
pub fn tls_key_path(&self) -> Option<&Path> {
self.tls_key_path.as_deref()
}
pub fn router_config(&self) -> &RouterConfig { pub fn router_config(&self) -> &RouterConfig {
&self.router_config &self.router_config
} }
......
...@@ -283,10 +283,11 @@ impl Component { ...@@ -283,10 +283,11 @@ impl Component {
let component_clone = self.clone(); let component_clone = self.clone();
let mut hierarchies = self.parent_hierarchy(); let mut hierarchies = self.parent_hierarchy();
hierarchies.push(self.hierarchy()); hierarchies.push(self.hierarchy());
debug_assert_eq!( debug_assert!(hierarchies
hierarchies.last().cloned().unwrap_or_default(), .last()
self.service_name() .map(|x| x.as_str())
); // it happens that in component, hierarchy and service name are the same .unwrap_or_default()
.eq_ignore_ascii_case(&self.service_name())); // it happens that in component, hierarchy and service name are the same
// Start a background task that scrapes stats every 5 seconds // Start a background task that scrapes stats every 5 seconds
let m = component_metrics.clone(); let m = component_metrics.clone();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment