Unverified Commit dbf17a83 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] Add mTLS Support for Router-to-Worker Communication (#12019)


Co-authored-by: Chang Su <chang.s.su@oracle.com>
parent 0f0c430e
......@@ -33,7 +33,7 @@ serde_json = { version = "1.0", default-features = false, features = [
] }
bytes = "1.8.0"
rand = "0.9.2"
reqwest = { version = "0.12.8", features = ["stream", "blocking", "json"] }
reqwest = { version = "0.12.8", features = ["stream", "blocking", "json", "rustls-tls"], default-features = false }
futures-util = "0.3"
futures = "0.3"
pyo3 = { version = "0.25.1", features = ["extension-module"] }
......
......@@ -106,6 +106,10 @@ class RouterArgs:
oracle_pool_min: int = 1
oracle_pool_max: int = 16
oracle_pool_timeout_secs: int = 30
# mTLS configuration for worker communication
client_cert_path: Optional[str] = None
client_key_path: Optional[str] = None
ca_cert_paths: List[str] = dataclasses.field(default_factory=list)
@staticmethod
def add_cli_args(
......@@ -575,6 +579,26 @@ class RouterArgs:
),
help="Oracle connection pool timeout in seconds (default: 30, env: ATP_POOL_TIMEOUT_SECS)",
)
# mTLS configuration
parser.add_argument(
f"--{prefix}client-cert-path",
type=str,
default=None,
help="Path to client certificate for mTLS authentication with workers",
)
parser.add_argument(
f"--{prefix}client-key-path",
type=str,
default=None,
help="Path to client private key for mTLS authentication with workers",
)
parser.add_argument(
f"--{prefix}ca-cert-paths",
type=str,
nargs="*",
default=[],
help="Path(s) to CA certificate(s) for verifying worker TLS certificates. Can specify multiple CAs.",
)
@classmethod
def from_cli_args(
......
"""
Generate self-signed certificates for mTLS integration testing.
Creates a Certificate Authority (CA), server certificates, and client certificates.
"""
import datetime
import ipaddress
from pathlib import Path
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID
def generate_private_key():
    """Create a fresh 2048-bit RSA private key (public exponent 65537).

    Returns:
        The key object produced by ``rsa.generate_private_key``.
    """
    key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    return key
def generate_ca_certificate():
    """Generate a self-signed CA certificate.

    The certificate is valid for 3650 days from the moment of creation, is
    marked as a CA (no path-length limit) and is restricted to signing
    certificates/CRLs and producing digital signatures.

    Returns:
        A ``(private_key, certificate)`` tuple; the certificate is signed
        with ``private_key`` using SHA-256.
    """
    private_key = generate_private_key()
    # Self-signed: subject and issuer are the same name.
    subject = issuer = x509.Name(
        [
            x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
            x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Test"),
            x509.NameAttribute(NameOID.LOCALITY_NAME, "Test"),
            x509.NameAttribute(NameOID.ORGANIZATION_NAME, "SGLang Test"),
            x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Test"),
            x509.NameAttribute(NameOID.COMMON_NAME, "Test CA"),
        ]
    )
    # Take one timezone-aware timestamp: datetime.utcnow() is deprecated
    # (Python 3.12+) and returns a naive value; a single reading also keeps
    # both validity bounds derived from the same instant.
    now = datetime.datetime.now(datetime.timezone.utc)
    cert = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(issuer)
        .public_key(private_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + datetime.timedelta(days=3650))
        .add_extension(
            x509.BasicConstraints(ca=True, path_length=None),
            critical=True,
        )
        .add_extension(
            x509.KeyUsage(
                digital_signature=True,
                key_cert_sign=True,
                crl_sign=True,
                key_encipherment=False,
                content_commitment=False,
                data_encipherment=False,
                key_agreement=False,
                encipher_only=False,
                decipher_only=False,
            ),
            critical=True,
        )
        .sign(private_key, hashes.SHA256())
    )
    return private_key, cert
def generate_server_certificate(ca_key, ca_cert):
    """Generate a server certificate signed by the CA.

    The certificate is issued for ``localhost`` / ``127.0.0.1`` (via the
    Subject Alternative Name extension), is valid for 365 days and is limited
    to TLS server authentication.

    Args:
        ca_key: CA private key used to sign the certificate.
        ca_cert: CA certificate; its subject becomes this certificate's issuer.

    Returns:
        A ``(private_key, certificate)`` tuple for the server.
    """
    private_key = generate_private_key()
    subject = x509.Name(
        [
            x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
            x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Test"),
            x509.NameAttribute(NameOID.LOCALITY_NAME, "Test"),
            x509.NameAttribute(NameOID.ORGANIZATION_NAME, "SGLang Test"),
            x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Test"),
            x509.NameAttribute(NameOID.COMMON_NAME, "localhost"),
        ]
    )
    # Take one timezone-aware timestamp: datetime.utcnow() is deprecated
    # (Python 3.12+) and returns a naive value; a single reading also keeps
    # both validity bounds derived from the same instant.
    now = datetime.datetime.now(datetime.timezone.utc)
    cert = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(ca_cert.subject)
        .public_key(private_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + datetime.timedelta(days=365))
        .add_extension(
            # Names the test workers are reached by.
            x509.SubjectAlternativeName(
                [
                    x509.DNSName("localhost"),
                    x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
                ]
            ),
            critical=False,
        )
        .add_extension(
            x509.KeyUsage(
                digital_signature=True,
                key_encipherment=True,
                key_cert_sign=False,
                crl_sign=False,
                content_commitment=False,
                data_encipherment=False,
                key_agreement=False,
                encipher_only=False,
                decipher_only=False,
            ),
            critical=True,
        )
        .add_extension(
            x509.ExtendedKeyUsage(
                [
                    x509.oid.ExtendedKeyUsageOID.SERVER_AUTH,
                ]
            ),
            critical=False,
        )
        .sign(ca_key, hashes.SHA256())
    )
    return private_key, cert
def generate_client_certificate(ca_key, ca_cert):
    """Generate a client certificate signed by the CA.

    The certificate has common name ``test-client``, is valid for 365 days
    and is limited to TLS client authentication.

    Args:
        ca_key: CA private key used to sign the certificate.
        ca_cert: CA certificate; its subject becomes this certificate's issuer.

    Returns:
        A ``(private_key, certificate)`` tuple for the client.
    """
    private_key = generate_private_key()
    subject = x509.Name(
        [
            x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
            x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Test"),
            x509.NameAttribute(NameOID.LOCALITY_NAME, "Test"),
            x509.NameAttribute(NameOID.ORGANIZATION_NAME, "SGLang Test"),
            x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Test"),
            x509.NameAttribute(NameOID.COMMON_NAME, "test-client"),
        ]
    )
    # Take one timezone-aware timestamp: datetime.utcnow() is deprecated
    # (Python 3.12+) and returns a naive value; a single reading also keeps
    # both validity bounds derived from the same instant.
    now = datetime.datetime.now(datetime.timezone.utc)
    cert = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(ca_cert.subject)
        .public_key(private_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + datetime.timedelta(days=365))
        .add_extension(
            x509.KeyUsage(
                digital_signature=True,
                key_encipherment=True,
                key_cert_sign=False,
                crl_sign=False,
                content_commitment=False,
                data_encipherment=False,
                key_agreement=False,
                encipher_only=False,
                decipher_only=False,
            ),
            critical=True,
        )
        .add_extension(
            x509.ExtendedKeyUsage(
                [
                    x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH,
                ]
            ),
            critical=False,
        )
        .sign(ca_key, hashes.SHA256())
    )
    return private_key, cert
def save_key(key, path: Path):
    """Write ``key`` to ``path`` as an unencrypted PEM file.

    The key is serialized in the traditional OpenSSL private-key format.
    Any existing file at ``path`` is overwritten.
    """
    pem_bytes = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    )
    with open(path, "wb") as out:
        out.write(pem_bytes)
def save_cert(cert, path: Path):
    """Write ``cert`` to ``path`` in PEM encoding, overwriting any existing file."""
    pem_bytes = cert.public_bytes(serialization.Encoding.PEM)
    with open(path, "wb") as out:
        out.write(pem_bytes)
def generate_all_certificates(output_dir: Path):
    """Create the CA, server and client key/certificate pairs under ``output_dir``.

    The directory is created if missing. Six PEM files are written
    (``ca``, ``server`` and ``client``, each with ``-key.pem`` and
    ``-cert.pem``) and a short summary is printed to stdout.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # The CA comes first: it signs the two leaf certificates below.
    print("==> Generating CA certificate...")
    ca_key, ca_cert = generate_ca_certificate()
    save_key(ca_key, output_dir / "ca-key.pem")
    save_cert(ca_cert, output_dir / "ca-cert.pem")

    print("==> Generating server certificate...")
    srv_key, srv_cert = generate_server_certificate(ca_key, ca_cert)
    save_key(srv_key, output_dir / "server-key.pem")
    save_cert(srv_cert, output_dir / "server-cert.pem")

    print("==> Generating client certificate...")
    cli_key, cli_cert = generate_client_certificate(ca_key, ca_cert)
    save_key(cli_key, output_dir / "client-key.pem")
    save_cert(cli_cert, output_dir / "client-cert.pem")

    print(f"==> Certificates generated successfully in {output_dir}")
    print()
    summary_lines = (
        "Files created:",
        " - ca-cert.pem : CA certificate (for verifying server/client certs)",
        " - ca-key.pem : CA private key",
        " - server-cert.pem : Server certificate",
        " - server-key.pem : Server private key",
        " - client-cert.pem : Client certificate",
        " - client-key.pem : Client private key",
    )
    for line in summary_lines:
        print(line)
    print()
    print("Test server can use: server-cert.pem + server-key.pem")
    print("Test router can use: client-cert.pem + client-key.pem + ca-cert.pem")
if __name__ == "__main__":
    # When run as a script, write the certificates into a "test_certs"
    # directory that sits next to this file.
    generate_all_certificates(Path(__file__).parent / "test_certs")
......@@ -47,6 +47,17 @@ def _parse_args() -> argparse.Namespace:
p.add_argument("--dp-size", type=int, default=1)
p.add_argument("--crash-on-request", action="store_true")
p.add_argument("--health-fail-after-ms", type=int, default=0)
# TLS/mTLS configuration
p.add_argument(
"--ssl-certfile", type=str, default=None, help="Path to SSL certificate file"
)
p.add_argument("--ssl-keyfile", type=str, default=None, help="Path to SSL key file")
p.add_argument(
"--ssl-ca-certs",
type=str,
default=None,
help="Path to CA certificates for client verification",
)
return p.parse_args()
......@@ -256,7 +267,18 @@ def main() -> None:
app = create_app(args)
# Handle SIGTERM gracefully for fast test teardown
signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))
uvicorn.run(app, host=args.host, port=args.port, log_level="warning")
# Configure SSL if certificates are provided
ssl_config = {}
if args.ssl_certfile and args.ssl_keyfile:
ssl_config["ssl_certfile"] = args.ssl_certfile
ssl_config["ssl_keyfile"] = args.ssl_keyfile
# If CA certs provided, require client certificates (mTLS)
if args.ssl_ca_certs:
ssl_config["ssl_ca_certs"] = args.ssl_ca_certs
ssl_config["ssl_cert_reqs"] = 2 # ssl.CERT_REQUIRED
uvicorn.run(app, host=args.host, port=args.port, log_level="warning", **ssl_config)
if __name__ == "__main__":
......
......@@ -101,6 +101,10 @@ class RouterManager:
"queue_size": "--queue-size",
"queue_timeout_secs": "--queue-timeout-secs",
"rate_limit_tokens_per_second": "--rate-limit-tokens-per-second",
# mTLS configuration
"client_cert_path": "--client-cert-path",
"client_key_path": "--client-key-path",
"ca_cert_paths": "--ca-cert-paths",
}
for k, v in extra.items():
if v is None:
......@@ -111,6 +115,11 @@ class RouterManager:
if isinstance(v, bool):
if v:
cmd.append(flag)
elif isinstance(v, list):
# Handle list arguments (e.g., ca_cert_paths)
if v: # Only add if list is not empty
cmd.append(flag)
cmd.extend([str(item) for item in v])
else:
cmd.extend([flag, str(v)])
......
import shutil
import subprocess
import time
from pathlib import Path
......@@ -6,6 +7,7 @@ from typing import Iterable, List, Optional, Tuple
import pytest
import requests
from ..fixtures.generate_test_certs import generate_all_certificates
from ..fixtures.ports import find_free_port
from ..fixtures.router_manager import RouterManager
......@@ -106,3 +108,21 @@ def mock_workers():
p.wait(timeout=3)
except subprocess.TimeoutExpired:
p.kill()
@pytest.fixture(scope="session")
def test_certificates():
    """Session fixture: create the mTLS test certificates and yield their directory.

    The certificates are written to ``fixtures/test_certs`` (a sibling of this
    file's parent directory) and the directory is deleted again at teardown.
    """
    tests_root = Path(__file__).parent.parent
    certs_dir = tests_root / "fixtures" / "test_certs"

    generate_all_certificates(certs_dir)

    # Hand the directory to the tests; everything below runs at teardown.
    yield certs_dir

    if not certs_dir.exists():
        return
    shutil.rmtree(certs_dir)
"""
Integration tests for mTLS (mutual TLS) authentication between router and workers.
Tests verify that:
1. Router can successfully connect to TLS-enabled workers with proper certificates
2. Router fails to connect to mTLS-required workers without client certificates
3. Router with CA certs can connect to TLS-only workers (server auth only)
"""
import subprocess
import time
from pathlib import Path
from typing import Tuple
import pytest
import requests
from ..fixtures.ports import find_free_port
def get_test_certs_dir() -> Path:
    """Return ``<this file's grandparent directory>/fixtures/test_certs``."""
    tests_root = Path(__file__).parent.parent
    return tests_root.joinpath("fixtures", "test_certs")
def _spawn_tls_worker(
    port: int,
    worker_id: str,
    ssl_certfile: str,
    ssl_keyfile: str,
    ssl_ca_certs: str = None,
) -> Tuple[subprocess.Popen, str]:
    """Spawn a mock worker with TLS/mTLS enabled and wait until it is healthy.

    Args:
        port: TCP port the worker listens on.
        worker_id: Identifier passed to the worker via ``--worker-id``.
        ssl_certfile: Path to the server certificate.
        ssl_keyfile: Path to the server private key.
        ssl_ca_certs: Optional CA bundle path. When given it is passed to the
            worker as ``--ssl-ca-certs`` and the health probe presents the
            test client certificate.

    Returns:
        ``(process, url)`` where ``url`` is ``https://127.0.0.1:<port>``.

    Raises:
        RuntimeError: The worker process exited before becoming healthy
            (its stderr is included in the message).
        TimeoutError: The worker stayed alive but never answered ``/health``.

    On any failure the spawned process is terminated before the exception
    propagates, so a failed startup does not leave a stray worker behind.
    """
    repo_root = Path(__file__).resolve().parents[2]
    script = repo_root / "py_test" / "fixtures" / "mock_worker.py"
    cmd = [
        "python3",
        str(script),
        "--port",
        str(port),
        "--worker-id",
        worker_id,
        "--ssl-certfile",
        ssl_certfile,
        "--ssl-keyfile",
        ssl_keyfile,
    ]
    if ssl_ca_certs:
        cmd.extend(["--ssl-ca-certs", ssl_ca_certs])
    # Use DEVNULL for stdout to avoid blocking, but keep stderr for debugging
    proc = subprocess.Popen(
        cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, text=True
    )
    url = f"https://127.0.0.1:{port}"
    try:
        # Give worker a moment to start or fail
        time.sleep(3)  # Increased delay to ensure TLS server is fully initialized
        # Check if process died immediately
        if proc.poll() is not None:
            _, stderr = proc.communicate()
            raise RuntimeError(f"Worker failed to start.\nStderr: {stderr}")
        # Wait for worker to be ready (with retries for SSL startup)
        # For mTLS workers (with ssl_ca_certs), provide client cert for health check
        certs_dir = get_test_certs_dir()
        client_cert = certs_dir / "client-cert.pem" if ssl_ca_certs else None
        client_key = certs_dir / "client-key.pem" if ssl_ca_certs else None
        try:
            _wait_tls_health(url, certs_dir / "ca-cert.pem", client_cert, client_key)
        except TimeoutError:
            # If health check times out, capture stderr for debugging
            if proc.poll() is not None:
                _, stderr = proc.communicate()
                raise RuntimeError(
                    f"Worker died during health check.\nStderr: {stderr}"
                )
            raise
    except BaseException:
        # The caller never receives `proc` on failure, so nobody else can
        # stop it: shut the worker down here before re-raising.
        if proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=3)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait()
        raise
    return proc, url
def _wait_tls_health(
    url: str,
    ca_cert_path: Path = None,
    client_cert_path: Path = None,
    client_key_path: Path = None,
    timeout: float = 10.0,
):
    """Wait for TLS-enabled worker to become healthy.

    Polls ``{url}/health`` every 0.2 seconds until it answers with HTTP 200.

    Args:
        url: HTTPS URL of the worker
        ca_cert_path: Path to CA certificate for verifying server cert;
            when omitted, server certificate verification is disabled
        client_cert_path: Path to client certificate for mTLS
        client_key_path: Path to client private key for mTLS
        timeout: Maximum time to wait in seconds

    Raises:
        TimeoutError: The worker did not return HTTP 200 within ``timeout``
            seconds; the message carries the last request error or the last
            unexpected status code.
    """
    # These do not change between attempts, so compute them once.
    # Verify server cert with CA if provided, otherwise skip verification
    verify = str(ca_cert_path) if ca_cert_path else False
    # Provide client cert for mTLS if specified (both halves are required)
    cert = None
    if client_cert_path and client_key_path:
        cert = (str(client_cert_path), str(client_key_path))

    # monotonic() cannot jump when the wall clock is adjusted.
    start = time.monotonic()
    last_error = None
    with requests.Session() as s:
        while time.monotonic() - start < timeout:
            try:
                r = s.get(f"{url}/health", timeout=1, verify=verify, cert=cert)
                if r.status_code == 200:
                    return
                # Remember non-200 answers too, so the timeout message is
                # never an uninformative "Last error: None".
                last_error = f"unexpected HTTP status {r.status_code}"
            except requests.RequestException as e:
                # Save last error for debugging
                last_error = e
            time.sleep(0.2)
    raise TimeoutError(
        f"TLS worker at {url} did not become healthy. Last error: {last_error}"
    )
@pytest.mark.integration
def test_mtls_successful_communication(router_manager, test_certificates):
    """A router configured with client cert, key and CA reaches an mTLS worker."""
    certs_dir = test_certificates
    ca_pem = str(certs_dir / "ca-cert.pem")

    # The worker demands a client certificate because it is given a CA bundle.
    worker_port = find_free_port()
    expected_worker_id = f"tls-worker-{worker_port}"
    worker_proc, worker_url = _spawn_tls_worker(
        port=worker_port,
        worker_id=expected_worker_id,
        ssl_certfile=str(certs_dir / "server-cert.pem"),
        ssl_keyfile=str(certs_dir / "server-key.pem"),
        ssl_ca_certs=ca_pem,
    )
    try:
        # Router gets the full mTLS configuration.
        mtls_args = {
            "client_cert_path": str(certs_dir / "client-cert.pem"),
            "client_key_path": str(certs_dir / "client-key.pem"),
            "ca_cert_paths": [ca_pem],
        }
        router = router_manager.start_router(
            worker_urls=[worker_url],
            policy="round_robin",
            extra=mtls_args,
        )

        payload = {
            "model": "test-model",
            "prompt": "hello",
            "max_tokens": 1,
            "stream": False,
        }
        resp = requests.post(f"{router.url}/v1/completions", json=payload, timeout=5)

        # The request must be served, and by the worker started above.
        assert resp.status_code == 200, f"Request failed: {resp.status_code} {resp.text}"
        body = resp.json()
        assert "choices" in body
        assert body.get("worker_id") == expected_worker_id
    finally:
        still_running = worker_proc.poll() is None
        if still_running:
            worker_proc.terminate()
            try:
                worker_proc.wait(timeout=3)
            except subprocess.TimeoutExpired:
                worker_proc.kill()
@pytest.mark.integration
def test_mtls_failure_without_client_cert(router_manager, test_certificates):
    """A router lacking a client certificate cannot use an mTLS-only worker."""
    certs_dir = test_certificates
    ca_pem = str(certs_dir / "ca-cert.pem")

    # The worker demands a client certificate because it is given a CA bundle.
    worker_port = find_free_port()
    worker_proc, worker_url = _spawn_tls_worker(
        port=worker_port,
        worker_id=f"tls-worker-{worker_port}",
        ssl_certfile=str(certs_dir / "server-cert.pem"),
        ssl_keyfile=str(certs_dir / "server-key.pem"),
        ssl_ca_certs=ca_pem,
    )
    try:
        # Router can verify the server (CA given) but has no client
        # certificate or key to present.
        router = router_manager.start_router(
            worker_urls=[worker_url],
            policy="round_robin",
            extra={"ca_cert_paths": [ca_pem]},
        )

        payload = {
            "model": "test-model",
            "prompt": "hello",
            "max_tokens": 1,
            "stream": False,
        }
        resp = requests.post(f"{router.url}/v1/completions", json=payload, timeout=5)

        # The router is expected to answer with 500 or 503 because it cannot
        # connect to the worker.
        assert resp.status_code in (
            500,
            503,
        ), f"Expected 500/503 but got {resp.status_code}"
    finally:
        still_running = worker_proc.poll() is None
        if still_running:
            worker_proc.terminate()
            try:
                worker_proc.wait(timeout=3)
            except subprocess.TimeoutExpired:
                worker_proc.kill()
@pytest.mark.integration
def test_tls_server_auth_only(router_manager, test_certificates):
    """A router with only a CA certificate reaches a server-auth-only TLS worker."""
    certs_dir = test_certificates

    # Plain TLS worker: no CA bundle, so no client certificate is requested.
    worker_port = find_free_port()
    expected_worker_id = f"tls-worker-{worker_port}"
    worker_proc, worker_url = _spawn_tls_worker(
        port=worker_port,
        worker_id=expected_worker_id,
        ssl_certfile=str(certs_dir / "server-cert.pem"),
        ssl_keyfile=str(certs_dir / "server-key.pem"),
        ssl_ca_certs=None,
    )
    try:
        # Router only needs the CA to verify the server; no client cert/key.
        router = router_manager.start_router(
            worker_urls=[worker_url],
            policy="round_robin",
            extra={"ca_cert_paths": [str(certs_dir / "ca-cert.pem")]},
        )

        payload = {
            "model": "test-model",
            "prompt": "hello",
            "max_tokens": 1,
            "stream": False,
        }
        resp = requests.post(f"{router.url}/v1/completions", json=payload, timeout=5)

        # The request must be served, and by the worker started above.
        assert resp.status_code == 200, f"Request failed: {resp.status_code} {resp.text}"
        body = resp.json()
        assert "choices" in body
        assert body.get("worker_id") == expected_worker_id
    finally:
        still_running = worker_proc.poll() is None
        if still_running:
            worker_proc.terminate()
            try:
                worker_proc.wait(timeout=3)
            except subprocess.TimeoutExpired:
                worker_proc.kill()
@pytest.mark.integration
def test_tls_failure_without_ca_cert(router_manager, test_certificates):
    """A router without the test CA cannot use a worker serving a test-CA cert."""
    certs_dir = test_certificates

    # Plain TLS worker (no client certificate requested).
    worker_port = find_free_port()
    worker_proc, worker_url = _spawn_tls_worker(
        port=worker_port,
        worker_id=f"tls-worker-{worker_port}",
        ssl_certfile=str(certs_dir / "server-cert.pem"),
        ssl_keyfile=str(certs_dir / "server-key.pem"),
        ssl_ca_certs=None,
    )
    try:
        # No ca_cert_paths at all: the router has nothing with which to
        # trust the worker's test-CA-signed certificate.
        no_tls_args = {}
        router = router_manager.start_router(
            worker_urls=[worker_url],
            policy="round_robin",
            extra=no_tls_args,
        )

        payload = {
            "model": "test-model",
            "prompt": "hello",
            "max_tokens": 1,
            "stream": False,
        }
        resp = requests.post(f"{router.url}/v1/completions", json=payload, timeout=5)

        # The router is expected to answer with 500 or 503 because it cannot
        # verify the worker certificate.
        assert resp.status_code in (
            500,
            503,
        ), f"Expected 500/503 but got {resp.status_code}"
    finally:
        still_running = worker_proc.poll() is None
        if still_running:
            worker_proc.terminate()
            try:
                worker_proc.wait(timeout=3)
            except subprocess.TimeoutExpired:
                worker_proc.kill()
use super::{
CircuitBreakerConfig, ConfigResult, DiscoveryConfig, HealthCheckConfig, HistoryBackend,
MetricsConfig, OracleConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
TokenizerCacheConfig,
CircuitBreakerConfig, ConfigError, ConfigResult, DiscoveryConfig, HealthCheckConfig,
HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig, RouterConfig,
RoutingMode, TokenizerCacheConfig,
};
use crate::core::ConnectionMode;
......@@ -10,6 +10,10 @@ use crate::core::ConnectionMode;
#[derive(Debug, Clone, Default)]
pub struct RouterConfigBuilder {
config: RouterConfig,
// Temporary fields for certificate paths (read during build)
client_cert_path: Option<String>,
client_key_path: Option<String>,
ca_cert_paths: Vec<String>,
}
impl RouterConfigBuilder {
......@@ -20,7 +24,12 @@ impl RouterConfigBuilder {
/// Create a builder from an existing configuration (takes ownership)
pub fn from_config(config: RouterConfig) -> Self {
Self { config }
Self {
config,
client_cert_path: None,
client_key_path: None,
ca_cert_paths: Vec::new(),
}
}
/// Create a builder from a reference to an existing configuration
......@@ -569,6 +578,48 @@ impl RouterConfigBuilder {
self
}
// ==================== mTLS Configuration ====================
/// Set client certificate and key paths for mTLS authentication
/// Both paths must be provided together
/// Files will be read during build()
pub fn client_cert_and_key<S1: Into<String>, S2: Into<String>>(
mut self,
cert_path: S1,
key_path: S2,
) -> Self {
self.client_cert_path = Some(cert_path.into());
self.client_key_path = Some(key_path.into());
self
}
/// Set client certificate and key paths for mTLS if both paths are provided
/// Files will be read during build()
pub fn maybe_client_cert_and_key(
mut self,
cert_path: Option<impl Into<String>>,
key_path: Option<impl Into<String>>,
) -> Self {
self.client_cert_path = cert_path.map(|p| p.into());
self.client_key_path = key_path.map(|p| p.into());
self
}
/// Append one CA certificate path used to verify worker TLS certificates.
///
/// Only the path is stored; the file is read later, during `build()`.
pub fn add_ca_certificate<S: Into<String>>(mut self, ca_cert_path: S) -> Self {
    let path: String = ca_cert_path.into();
    self.ca_cert_paths.push(path);
    self
}
/// Append several CA certificate paths used to verify worker TLS certificates.
///
/// Paths are appended in the given order after any already stored; the
/// files are read later, during `build()`.
pub fn add_ca_certificates<S: Into<String>>(mut self, ca_cert_paths: Vec<S>) -> Self {
    for path in ca_cert_paths {
        self.ca_cert_paths.push(path.into());
    }
    self
}
// ==================== Builder Methods ====================
/// Build the RouterConfig, validating if requested
......@@ -582,13 +633,68 @@ impl RouterConfigBuilder {
}
/// Build with optional validation
pub fn build_with_validation(self, validate: bool) -> ConfigResult<RouterConfig> {
pub fn build_with_validation(mut self, validate: bool) -> ConfigResult<RouterConfig> {
// Read mTLS certificates from paths if provided
self = self.read_mtls_certificates()?;
let config: RouterConfig = self.into();
if validate {
config.validate()?;
}
Ok(config)
}
/// Load the configured mTLS files from disk into the config being built.
///
/// * Client cert + key: both files are read and concatenated (certificate
///   first, then key, each terminated by a newline) into
///   `config.client_identity`.
/// * Neither set: `config.client_identity` is left untouched.
/// * Only one of the two set: returns `ConfigError::ValidationFailed`.
/// * Every CA path is read and appended to `config.ca_certificates`.
///
/// Any file that cannot be read also yields `ConfigError::ValidationFailed`.
fn read_mtls_certificates(mut self) -> ConfigResult<Self> {
    let paths = (
        self.client_cert_path.as_deref(),
        self.client_key_path.as_deref(),
    );

    let identity: Option<Vec<u8>> = match paths {
        // No client cert configured, that's fine
        (None, None) => None,
        (Some(cert_path), Some(key_path)) => {
            let mut pem =
                std::fs::read(cert_path).map_err(|e| ConfigError::ValidationFailed {
                    reason: format!(
                        "Failed to read client certificate from {}: {}",
                        cert_path, e
                    ),
                })?;
            let key_pem = std::fs::read(key_path).map_err(|e| ConfigError::ValidationFailed {
                reason: format!("Failed to read client key from {}: {}", key_path, e),
            })?;

            // reqwest::Identity takes a single PEM buffer; with rustls the
            // certificate must precede the key, and each part has to end
            // with a newline so the PEM blocks do not run together.
            if !pem.ends_with(b"\n") {
                pem.push(b'\n');
            }
            pem.extend_from_slice(&key_pem);
            if !pem.ends_with(b"\n") {
                pem.push(b'\n');
            }
            Some(pem)
        }
        _ => {
            return Err(ConfigError::ValidationFailed {
                reason: "Both --client-cert-path and --client-key-path must be specified together"
                    .to_string(),
            });
        }
    };
    if let Some(pem) = identity {
        self.config.client_identity = Some(pem);
    }

    // Read CA certificates
    for ca_path in &self.ca_cert_paths {
        let ca_pem = std::fs::read(ca_path).map_err(|e| ConfigError::ValidationFailed {
            reason: format!("Failed to read CA certificate from {}: {}", ca_path, e),
        })?;
        self.config.ca_certificates.push(ca_pem);
    }

    Ok(self)
}
}
impl From<RouterConfigBuilder> for RouterConfig {
......
......@@ -85,6 +85,14 @@ pub struct RouterConfig {
/// Tokenizer cache configuration
#[serde(default)]
pub tokenizer_cache: TokenizerCacheConfig,
/// mTLS client identity (combined certificate + key in PEM format)
/// This is loaded from client_cert_path and client_key_path during config creation
#[serde(skip)]
pub client_identity: Option<Vec<u8>>,
/// CA certificates for verifying worker TLS certificates (PEM format)
/// Loaded from ca_cert_paths during config creation
#[serde(default)]
pub ca_certificates: Vec<Vec<u8>>,
}
/// Tokenizer cache configuration
......@@ -498,6 +506,8 @@ impl Default for RouterConfig {
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: TokenizerCacheConfig::default(),
client_identity: None,
ca_certificates: vec![],
}
}
}
......
......@@ -469,6 +469,29 @@ impl ConfigValidator {
Ok(())
}
/// Check the loaded mTLS material for obviously unusable values.
///
/// Fails with `ConfigError::ValidationFailed` when a client identity is
/// present but zero-length, or when any CA certificate is zero-length (the
/// error names the index of the first such entry). No PEM parsing happens
/// here.
fn validate_mtls(config: &RouterConfig) -> ConfigResult<()> {
    let identity_is_blank = matches!(&config.client_identity, Some(pem) if pem.is_empty());
    if identity_is_blank {
        return Err(ConfigError::ValidationFailed {
            reason: "Client identity cannot be empty".to_string(),
        });
    }

    let first_blank_ca = config
        .ca_certificates
        .iter()
        .position(|pem| pem.is_empty());
    if let Some(idx) = first_blank_ca {
        return Err(ConfigError::ValidationFailed {
            reason: format!("CA certificate at index {} cannot be empty", idx),
        });
    }

    Ok(())
}
/// Validate compatibility between different configuration sections
fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> {
// IGW mode is independent - skip other compatibility checks when enabled
......@@ -486,6 +509,9 @@ impl ConfigValidator {
});
}
// Validate mTLS configuration
Self::validate_mtls(config)?;
// All policies are now supported for both router types thanks to the unified trait design
// No mode/policy restrictions needed anymore
......
......@@ -123,16 +123,29 @@ fn strip_protocol(url: &str) -> String {
}
/// Helper: Try HTTP health check
async fn try_http_health_check(url: &str, timeout_secs: u64) -> Result<(), String> {
///
/// Uses the provided client (from app_context) which supports both HTTP and HTTPS.
/// For HTTPS URLs, the client's TLS configuration (mTLS, CA certs) is used.
/// For plain HTTP URLs, the client handles them normally without TLS overhead.
async fn try_http_health_check(
url: &str,
timeout_secs: u64,
client: &Client,
) -> Result<(), String> {
// Preserve the protocol (http or https) from the original URL
let is_https = url.starts_with("https://");
let protocol = if is_https { "https" } else { "http" };
let clean_url = strip_protocol(url);
let health_url = format!("http://{}/health", clean_url);
let health_url = format!("{}://{}/health", protocol, clean_url);
HTTP_CLIENT
// Use the AppContext client for both HTTP and HTTPS
// The rustls backend handles both protocols correctly
client
.get(&health_url)
.timeout(Duration::from_secs(timeout_secs))
.send()
.await
.map_err(|e| format!("HTTP health check failed: {}", e))?;
.map_err(|e| format!("Health check failed: {}", e))?;
Ok(())
}
......@@ -235,6 +248,9 @@ impl StepExecutor for DetectConnectionModeStep {
let config: Arc<WorkerConfigRequest> = context
.get("worker_config")
.ok_or_else(|| WorkflowError::ContextValueNotFound("worker_config".to_string()))?;
let app_context: Arc<AppContext> = context
.get("app_context")
.ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?;
debug!(
"Detecting connection mode for {} (timeout: {}s, max_attempts: {})",
......@@ -242,10 +258,12 @@ impl StepExecutor for DetectConnectionModeStep {
);
// Try both protocols in parallel using configured timeout
// Use the AppContext client which has TLS configuration (CA certs, client identity)
let url = config.url.clone();
let timeout = config.health_check_timeout_secs;
let client = &app_context.client;
let (http_result, grpc_result) = tokio::join!(
try_http_health_check(&url, timeout),
try_http_health_check(&url, timeout, client),
try_grpc_health_check(&url, timeout)
);
......
......@@ -207,6 +207,9 @@ struct Router {
backend: BackendType,
history_backend: HistoryBackendType,
oracle_config: Option<PyOracleConfig>,
client_cert_path: Option<String>,
client_key_path: Option<String>,
ca_cert_paths: Vec<String>,
}
impl Router {
......@@ -302,7 +305,7 @@ impl Router {
None
};
let builder = config::RouterConfig::builder()
config::RouterConfig::builder()
.mode(mode)
.policy(policy)
.host(&self.host)
......@@ -359,9 +362,13 @@ impl Router {
.dp_aware(self.dp_aware)
.retries(!self.disable_retries)
.circuit_breaker(!self.disable_circuit_breaker)
.igw(self.enable_igw);
builder.build()
.igw(self.enable_igw)
.maybe_client_cert_and_key(
self.client_cert_path.as_ref(),
self.client_key_path.as_ref(),
)
.add_ca_certificates(self.ca_cert_paths.clone())
.build()
}
}
......@@ -435,6 +442,9 @@ impl Router {
backend = BackendType::Sglang,
history_backend = HistoryBackendType::Memory,
oracle_config = None,
client_cert_path = None,
client_key_path = None,
ca_cert_paths = vec![],
))]
#[allow(clippy::too_many_arguments)]
fn new(
......@@ -504,6 +514,9 @@ impl Router {
backend: BackendType,
history_backend: HistoryBackendType,
oracle_config: Option<PyOracleConfig>,
client_cert_path: Option<String>,
client_key_path: Option<String>,
ca_cert_paths: Vec<String>,
) -> PyResult<Self> {
let mut all_urls = worker_urls.clone();
......@@ -587,6 +600,9 @@ impl Router {
backend,
history_backend,
oracle_config,
client_cert_path,
client_key_path,
ca_cert_paths,
})
}
......
......@@ -799,13 +799,59 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
config.max_payload_size / (1024 * 1024)
);
let client = Client::builder()
// FIXME: Current implementation creates a single HTTP client for all workers.
// This works well for single security domain deployments where all workers share
// the same CA and can accept the same client certificate.
//
// For multi-domain deployments (e.g., different model families with different CAs),
// this architecture needs significant refactoring:
// 1. Move client creation into worker registration workflow (per-worker clients)
// 2. Store client per worker in WorkerRegistry
// 3. Update PDRouter and other routers to fetch client from worker
// 4. Add per-worker TLS spec in WorkerConfigRequest
//
// Current single-domain approach is sufficient for most deployments.
//
// Use rustls TLS backend when TLS/mTLS is configured (client cert or CA certs provided).
// This ensures proper PKCS#8 key format support. For plain HTTP workers, use default
// backend to avoid unnecessary TLS initialization overhead.
let has_tls_config = config.router_config.client_identity.is_some()
|| !config.router_config.ca_certificates.is_empty();
let mut client_builder = Client::builder()
.pool_idle_timeout(Some(Duration::from_secs(50)))
.pool_max_idle_per_host(500)
.timeout(Duration::from_secs(config.request_timeout_secs))
.connect_timeout(Duration::from_secs(10))
.tcp_nodelay(true)
.tcp_keepalive(Some(Duration::from_secs(30)))
.tcp_keepalive(Some(Duration::from_secs(30)));
// Force rustls backend when TLS is configured
if has_tls_config {
client_builder = client_builder.use_rustls_tls();
info!("Using rustls TLS backend for TLS/mTLS connections");
}
// Configure mTLS client identity if provided (certificates already loaded during config creation)
if let Some(identity_pem) = &config.router_config.client_identity {
let identity = reqwest::Identity::from_pem(identity_pem)?;
client_builder = client_builder.identity(identity);
info!("mTLS client authentication enabled");
}
// Add CA certificates for verifying worker TLS (certificates already loaded during config creation)
for ca_cert in &config.router_config.ca_certificates {
let cert = reqwest::Certificate::from_pem(ca_cert)?;
client_builder = client_builder.add_root_certificate(cert);
}
if !config.router_config.ca_certificates.is_empty() {
info!(
"Added {} CA certificate(s) for worker verification",
config.router_config.ca_certificates.len()
);
}
let client = client_builder
.build()
.expect("Failed to create HTTP client");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment