Unverified Commit 18b64e90 authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

feat: switch to tcp as default request plane (#4845)

parent b401a1d6
...@@ -270,7 +270,7 @@ def parse_args(): ...@@ -270,7 +270,7 @@ def parse_args():
"--request-plane", "--request-plane",
type=str, type=str,
choices=["nats", "http", "tcp"], choices=["nats", "http", "tcp"],
default=os.environ.get("DYN_REQUEST_PLANE", "nats"), default=os.environ.get("DYN_REQUEST_PLANE", "tcp"),
help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]", help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
) )
parser.add_argument( parser.add_argument(
......
...@@ -302,7 +302,7 @@ def parse_args(): ...@@ -302,7 +302,7 @@ def parse_args():
"--request-plane", "--request-plane",
type=str, type=str,
choices=["nats", "http", "tcp"], choices=["nats", "http", "tcp"],
default=os.environ.get("DYN_REQUEST_PLANE", "nats"), default=os.environ.get("DYN_REQUEST_PLANE", "tcp"),
help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]", help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
) )
......
...@@ -113,7 +113,7 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = { ...@@ -113,7 +113,7 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = {
"flags": ["--request-plane"], "flags": ["--request-plane"],
"type": str, "type": str,
"choices": ["nats", "http", "tcp"], "choices": ["nats", "http", "tcp"],
"default": os.environ.get("DYN_REQUEST_PLANE", "nats"), "default": os.environ.get("DYN_REQUEST_PLANE", "tcp"),
"help": "Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]", "help": "Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
}, },
"enable-local-indexer": { "enable-local-indexer": {
......
...@@ -302,7 +302,7 @@ def cmd_line_args(): ...@@ -302,7 +302,7 @@ def cmd_line_args():
"--request-plane", "--request-plane",
type=str, type=str,
choices=["nats", "http", "tcp"], choices=["nats", "http", "tcp"],
default=os.environ.get("DYN_REQUEST_PLANE", "nats"), default=os.environ.get("DYN_REQUEST_PLANE", "tcp"),
help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]", help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
) )
parser.add_argument( parser.add_argument(
......
...@@ -202,7 +202,7 @@ def parse_args() -> Config: ...@@ -202,7 +202,7 @@ def parse_args() -> Config:
"--request-plane", "--request-plane",
type=str, type=str,
choices=["nats", "http", "tcp"], choices=["nats", "http", "tcp"],
default=os.environ.get("DYN_REQUEST_PLANE", "nats"), default=os.environ.get("DYN_REQUEST_PLANE", "tcp"),
help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]", help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
) )
parser.add_argument( parser.add_argument(
......
...@@ -21,8 +21,8 @@ limitations under the License. ...@@ -21,8 +21,8 @@ limitations under the License.
Dynamo supports multiple transport mechanisms for its request plane (the communication layer between services). You can choose from three different request plane modes based on your deployment requirements: Dynamo supports multiple transport mechanisms for its request plane (the communication layer between services). You can choose from three different request plane modes based on your deployment requirements:
- **NATS** (default): Message broker-based request plane - **TCP** (default): Direct TCP connection for optimal performance
- **TCP**: Direct TCP connection for optimal performance - **NATS**: Message broker-based request plane
- **HTTP**: HTTP/2-based request plane - **HTTP**: HTTP/2-based request plane
This guide explains how to configure and use request plane in your Dynamo deployment. This guide explains how to configure and use request plane in your Dynamo deployment.
...@@ -37,16 +37,21 @@ The request plane is the transport layer that handles communication between Dyna ...@@ -37,16 +37,21 @@ The request plane is the transport layer that handles communication between Dyna
| **TCP** | Low-latency direct communication | Direct connections, minimal overhead | | **TCP** | Low-latency direct communication | Direct connections, minimal overhead |
| **HTTP** | Standard deployments, debugging | HTTP/2 protocol, easier observability with standard tools, widely compatible | | **HTTP** | Standard deployments, debugging | HTTP/2 protocol, easier observability with standard tools, widely compatible |
## KV Routing and NATS ## Request Plane vs KV Event Plane
Dynamo's Key-Value (KV) cache based routing optimizes large language model inference by intelligently directing requests to workers with the most relevant KV cache data. KV-aware routing improves both Time To First Token (TTFT) through better cache locality and Inter-Token Latency (ITL) through intelligent load balancing. Dynamo has **two independent communication planes**:
Please refer to the [KV Cache Routing documentation](../router/kv_cache_routing.md) for more details. - **Request plane** (**`DYN_REQUEST_PLANE`**): how **RPC requests** flow between components (frontend → router → worker), via `tcp`, `http`, or `nats`.
- **KV event plane** (currently only **NATS** is supported): how **KV cache events** (and optional router replica sync) are distributed/persisted for KV-aware routing.
There are two modes of KV based routing: **Note:** if you are using `tcp` or `http` request plane and choose to use NATS for KV events, you must still configure NATS server using `NATS_SERVER` environment variable, e.g. `NATS_SERVER=nats://nats-hostname:port`.
- Exact KV routing (needs NATS): KV routing is based KV events indexing in a radix tree scoring the best match for the request. *This requires NATS* to persist and distribute KV events across routers.
- Approximate KV routing (does not need NATS): KV routing is based on approximate load heuristics. *This does not require NATS*. Because they are independent, you can mix them.
For example, a deployment with TCP request plane can use different KV event planes:
- **JetStream KV events**: requests use TCP, KV routing still uses NATS JetStream + object store for persistence.
- **NATS Core KV events (local indexer)**: requests use TCP, KV events use NATS Core pub/sub and persistence lives on workers.
- **no KV events**: requests use TCP and KV routing runs without events (no NATS required, but no event-backed persistence).
## Configuration ## Configuration
...@@ -59,51 +64,27 @@ export DYN_REQUEST_PLANE=<mode> ...@@ -59,51 +64,27 @@ export DYN_REQUEST_PLANE=<mode>
``` ```
Where `<mode>` is one of: Where `<mode>` is one of:
- `nats` (default) - `tcp` (default)
- `tcp` - `nats`
- `http` - `http`
The value is case-insensitive. The value is case-insensitive.
### Default Behavior ### Default Behavior
If `DYN_REQUEST_PLANE` is not set or contains an invalid value, Dynamo defaults to `nats`. If `DYN_REQUEST_PLANE` is not set or contains an invalid value, Dynamo defaults to `tcp`.
## Usage Examples ## Usage Examples
### Using NATS (Default) ### Using TCP (Default)
NATS is the default request plane and provides the most flexibility for complex deployments.
**Prerequisites:**
- NATS server must be running and accessible
- Configure NATS connection via standard Dynamo NATS environment variables
```bash
# Explicitly set to NATS (optional, as it's the default)
# Run your Dynamo service
DYN_REQUEST_PLANE=nats python -m dynamo.frontend --http-port=8000 &
DYN_REQUEST_PLANE=nats python -m dynamo.vllm --model Qwen/Qwen3-0.6B
```
**When to use NATS:**
- Production deployments with service discovery
- Currently (HA) highly available routers require durable messages persisted in NATS message broker. If you want to completely disable NATS, KV based routing won't be available
- Multiple frontends and backends
- Need for message replay and persistence features
Limitations:
- NATS does not support payloads beyond 16MB (use TCP for larger payloads)
### Using TCP
TCP provides direct, low-latency communication between services. TCP is the default request plane and provides direct, low-latency communication between services.
**Configuration:** **Configuration:**
```bash ```bash
# Set request plane to TCP # TCP is the default, so no need to set DYN_REQUEST_PLANE explicitly
# But you can explicitly set it if desired:
export DYN_REQUEST_PLANE=tcp export DYN_REQUEST_PLANE=tcp
# Optional: Configure TCP server host and port # Optional: Configure TCP server host and port
...@@ -119,7 +100,7 @@ DYN_REQUEST_PLANE=tcp python -m dynamo.vllm --model Qwen/Qwen3-0.6B ...@@ -119,7 +100,7 @@ DYN_REQUEST_PLANE=tcp python -m dynamo.vllm --model Qwen/Qwen3-0.6B
**When to use TCP:** **When to use TCP:**
- Simple deployments with direct service-to-service communication (e.g. frontend to backend) - Simple deployments with direct service-to-service communication (e.g. frontend to backend)
- Minimal infrastructure requirements (no NATS needed) - Minimal infrastructure requirements (**no NATS needed unless you enable KV-event-backed routing/replica sync**)
- Low-latency requirements - Low-latency requirements
**TCP Configuration Options:** **TCP Configuration Options:**
...@@ -172,6 +153,31 @@ Additional HTTP-specific environment variables: ...@@ -172,6 +153,31 @@ Additional HTTP-specific environment variables:
- `DYN_HTTP2_KEEP_ALIVE_TIMEOUT_SECS`: Keep-alive timeout for HTTP client (default: 10 seconds) - `DYN_HTTP2_KEEP_ALIVE_TIMEOUT_SECS`: Keep-alive timeout for HTTP client (default: 10 seconds)
- `DYN_HTTP2_ADAPTIVE_WINDOW`: Enable adaptive flow control (default: true) - `DYN_HTTP2_ADAPTIVE_WINDOW`: Enable adaptive flow control (default: true)
### Using NATS
NATS provides durable jetstream messaging for request plane and can be used for KV events (and router replica sync).
**Prerequisites:**
- NATS server must be running and accessible
- Configure NATS connection via standard Dynamo NATS environment variables
```bash
# Explicitly set to NATS
export DYN_REQUEST_PLANE=nats
# Run your Dynamo service
DYN_REQUEST_PLANE=nats python -m dynamo.frontend --http-port=8000 &
DYN_REQUEST_PLANE=nats python -m dynamo.vllm --model Qwen/Qwen3-0.6B
```
**When to use NATS:**
- Production deployments with service discovery
- Currently KV based routing require NATS. If you want to completely disable NATS, KV based routing won't be available
- Need for message replay and persistence features
Limitations:
- NATS does not support payloads beyond 16MB (use TCP for larger payloads)
## Complete Example ## Complete Example
Here's a complete example showing how to launch a Dynamo deployment with different request planes: Here's a complete example showing how to launch a Dynamo deployment with different request planes:
...@@ -221,7 +227,7 @@ This abstraction means your application code doesn't need to change when switchi ...@@ -221,7 +227,7 @@ This abstraction means your application code doesn't need to change when switchi
Request plane configuration is loaded from environment variables at startup and cached globally. The configuration hierarchy is: Request plane configuration is loaded from environment variables at startup and cached globally. The configuration hierarchy is:
1. **Mode Selection**: `DYN_REQUEST_PLANE` (defaults to `nats`) 1. **Mode Selection**: `DYN_REQUEST_PLANE` (defaults to `tcp`)
2. **Transport-Specific Config**: Mode-specific environment variables (e.g., `DYN_TCP_*`, `DYN_HTTP2_*`) 2. **Transport-Specific Config**: Mode-specific environment variables (e.g., `DYN_TCP_*`, `DYN_HTTP2_*`)
## Migration Guide ## Migration Guide
...@@ -274,7 +280,7 @@ curl http://localhost:8000/v1/chat/completions \ ...@@ -274,7 +280,7 @@ curl http://localhost:8000/v1/chat/completions \
**Solutions:** **Solutions:**
- Check `DYN_REQUEST_PLANE` spelling (valid values: `nats`, `tcp`, `http`) - Check `DYN_REQUEST_PLANE` spelling (valid values: `nats`, `tcp`, `http`)
- Value is case-insensitive but must be one of the three options - Value is case-insensitive but must be one of the three options
- If not set, defaults to `nats` - If not set, defaults to `tcp`
### Issue: Port Conflicts ### Issue: Port Conflicts
......
...@@ -47,6 +47,11 @@ The main KV-aware routing arguments: ...@@ -47,6 +47,11 @@ The main KV-aware routing arguments:
> - **NATS Core with Local Indexer mode** (`--enable-local-indexer` on workers): State persists on workers—router rebuilds state by querying workers on startup. > - **NATS Core with Local Indexer mode** (`--enable-local-indexer` on workers): State persists on workers—router rebuilds state by querying workers on startup.
> - **No KV events** (`--no-kv-events`): State persistence is not supported. > - **No KV events** (`--no-kv-events`): State persistence is not supported.
> >
> **Request plane is independent of KV event transport.**
> `DYN_REQUEST_PLANE` controls how **requests** are sent (TCP/HTTP/NATS), but KV-aware routing still uses **NATS** for KV events in both JetStream and NATS Core + Local Indexer modes.
> If you run with `DYN_REQUEST_PLANE=tcp` (or `http`) and KV events enabled (default), you must also configure NATS, e.g. `NATS_SERVER=nats://...`.
> Only `--no-kv-events` removes the NATS requirement.
>
> When `--kv-overlap-score-weight` is set to 0 or `--no-kv-events` is set, no KvIndexer will be launched to drain and process KV events. It's recommended to disable your backend workers from relaying events through `KvEventPublisher` to avoid event accumulation in JetStream. WIP to enable disabling publishing of KV events completely in these cases. > When `--kv-overlap-score-weight` is set to 0 or `--no-kv-events` is set, no KvIndexer will be launched to drain and process KV events. It's recommended to disable your backend workers from relaying events through `KvEventPublisher` to avoid event accumulation in JetStream. WIP to enable disabling publishing of KV events completely in these cases.
> >
> The cli args `--router-ttl`, `--router-max-tree-size`, and `--router-prune-target-ratio` control local cache management when the router operates without receiving events from workers. When KV events are enabled (default), the router relies on worker-side eviction events and these parameters are ignored. > The cli args `--router-ttl`, `--router-max-tree-size`, and `--router-prune-target-ratio` control local cache management when the router operates without receiving events from workers. When KV events are enabled (default), the router relies on worker-side eviction events and these parameters are ignored.
......
...@@ -127,7 +127,7 @@ pub struct Flags { ...@@ -127,7 +127,7 @@ pub struct Flags {
pub store_kv: String, pub store_kv: String,
/// Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]. /// Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp].
#[arg(long, default_value = "nats", value_parser = ["nats", "http", "tcp"])] #[arg(long, default_value = "tcp", value_parser = ["nats", "http", "tcp"])]
pub request_plane: String, pub request_plane: String,
/// Everything after a `--`. Not currently used. /// Everything after a `--`. Not currently used.
......
...@@ -573,8 +573,15 @@ impl DistributedRuntime { ...@@ -573,8 +573,15 @@ impl DistributedRuntime {
let runtime_config = DistributedConfig { let runtime_config = DistributedConfig {
store_backend: selected_kv_store, store_backend: selected_kv_store,
// We only need NATS here to monitor it's metrics, so only if it's our request plane. // NATS is used for more than just the NATS request-plane:
nats_config: if request_plane.is_nats() { // - KV router events (JetStream or NATS core + local indexer)
// - inter-router replica sync (NATS core)
//
// If a NATS server is configured via env, enable the client regardless of request plane.
nats_config: if request_plane.is_nats()
|| std::env::var(dynamo_runtime::config::environment_names::nats::NATS_SERVER)
.is_ok()
{
Some(dynamo_runtime::transports::nats::ClientOptions::default()) Some(dynamo_runtime::transports::nats::ClientOptions::default())
} else { } else {
None None
......
...@@ -26,7 +26,7 @@ def dynamo_worker(): ...@@ -26,7 +26,7 @@ def dynamo_worker():
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
request_plane = os.environ.get("DYN_REQUEST_PLANE", "nats") request_plane = os.environ.get("DYN_REQUEST_PLANE", "tcp")
runtime = DistributedRuntime(loop, "etcd", request_plane) runtime = DistributedRuntime(loop, "etcd", request_plane)
await func(runtime, *args, **kwargs) await func(runtime, *args, **kwargs)
......
...@@ -530,9 +530,18 @@ pub struct DistributedConfig { ...@@ -530,9 +530,18 @@ pub struct DistributedConfig {
impl DistributedConfig { impl DistributedConfig {
pub fn from_settings() -> DistributedConfig { pub fn from_settings() -> DistributedConfig {
let request_plane = RequestPlaneMode::from_env(); let request_plane = RequestPlaneMode::from_env();
// NATS is used for more than just NATS request-plane RPC:
// - KV router events (JetStream or NATS core + local indexer)
// - inter-router replica sync (NATS core)
//
// Historically we only connected to NATS when the request plane was NATS, which made
// `DYN_REQUEST_PLANE=tcp|http` incompatible with KV routing modes that rely on NATS.
// If a NATS server is configured via env, enable the client regardless of request plane.
let nats_enabled = request_plane.is_nats()
|| std::env::var(crate::config::environment_names::nats::NATS_SERVER).is_ok();
DistributedConfig { DistributedConfig {
store_backend: kv::Selector::Etcd(Box::default()), store_backend: kv::Selector::Etcd(Box::default()),
nats_config: if request_plane.is_nats() { nats_config: if nats_enabled {
Some(nats::ClientOptions::default()) Some(nats::ClientOptions::default())
} else { } else {
None None
...@@ -547,9 +556,11 @@ impl DistributedConfig { ...@@ -547,9 +556,11 @@ impl DistributedConfig {
..Default::default() ..Default::default()
}; };
let request_plane = RequestPlaneMode::from_env(); let request_plane = RequestPlaneMode::from_env();
let nats_enabled = request_plane.is_nats()
|| std::env::var(crate::config::environment_names::nats::NATS_SERVER).is_ok();
DistributedConfig { DistributedConfig {
store_backend: kv::Selector::Etcd(Box::new(etcd_config)), store_backend: kv::Selector::Etcd(Box::new(etcd_config)),
nats_config: if request_plane.is_nats() { nats_config: if nats_enabled {
Some(nats::ClientOptions::default()) Some(nats::ClientOptions::default())
} else { } else {
None None
...@@ -574,12 +585,12 @@ impl DistributedConfig { ...@@ -574,12 +585,12 @@ impl DistributedConfig {
/// Request plane transport mode configuration /// Request plane transport mode configuration
/// ///
/// This determines how requests are distributed from routers to workers: /// This determines how requests are distributed from routers to workers:
/// - `Nats`: Use NATS for request distribution (default, legacy) /// - `Nats`: Use NATS for request distribution (legacy)
/// - `Http`: Use HTTP/2 for request distribution /// - `Http`: Use HTTP/2 for request distribution
/// - `Tcp`: Use raw TCP for request distribution with msgpack support /// - `Tcp`: Use raw TCP for request distribution with msgpack support (default)
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestPlaneMode { pub enum RequestPlaneMode {
/// Use NATS for request plane (default for backward compatibility) /// Use NATS for request plane
Nats, Nats,
/// Use HTTP/2 for request plane /// Use HTTP/2 for request plane
Http, Http,
...@@ -589,7 +600,7 @@ pub enum RequestPlaneMode { ...@@ -589,7 +600,7 @@ pub enum RequestPlaneMode {
impl Default for RequestPlaneMode { impl Default for RequestPlaneMode {
fn default() -> Self { fn default() -> Self {
Self::Nats Self::Tcp
} }
} }
......
...@@ -562,14 +562,16 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane): ...@@ -562,14 +562,16 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane):
leak across workers. leak across workers.
- If store_kv != "etcd", etcd is not started (returns None) - If store_kv != "etcd", etcd is not started (returns None)
- If request_plane != "nats", NATS is not started (returns None) - NATS is always started when etcd is used, because KV events require NATS
regardless of the request_plane (tcp/nats only affects request transport)
Returns a tuple of (nats_process, etcd_process) where each has a .port attribute. Returns a tuple of (nats_process, etcd_process) where each has a .port attribute.
""" """
import os import os
# Port cleanup is now handled in NatsServer and EtcdServer __exit__ methods # Port cleanup is now handled in NatsServer and EtcdServer __exit__ methods
if request_plane == "nats" and store_kv == "etcd": # Always start NATS when etcd is used - KV events require NATS regardless of request_plane
if store_kv == "etcd":
with NatsServer(request, port=0) as nats_process: with NatsServer(request, port=0) as nats_process:
with EtcdServer(request, port=0) as etcd_process: with EtcdServer(request, port=0) as etcd_process:
# Set environment variables for Rust/Python runtime to use. Note that xdist (parallel execution) # Set environment variables for Rust/Python runtime to use. Note that xdist (parallel execution)
...@@ -587,11 +589,6 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane): ...@@ -587,11 +589,6 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane):
os.environ["NATS_SERVER"] = f"nats://localhost:{nats_process.port}" os.environ["NATS_SERVER"] = f"nats://localhost:{nats_process.port}"
yield nats_process, None yield nats_process, None
os.environ.pop("NATS_SERVER", None) os.environ.pop("NATS_SERVER", None)
elif store_kv == "etcd":
with EtcdServer(request, port=0) as etcd_process:
os.environ["ETCD_ENDPOINTS"] = f"http://localhost:{etcd_process.port}"
yield None, etcd_process
os.environ.pop("ETCD_ENDPOINTS", None)
else: else:
yield None, None yield None, None
......
...@@ -367,12 +367,12 @@ async def send_request_with_retry(url: str, payload: dict, max_retries: int = 8) ...@@ -367,12 +367,12 @@ async def send_request_with_retry(url: str, payload: dict, max_retries: int = 8)
return False return False
def get_runtime(store_backend="etcd", request_plane="nats"): def get_runtime(store_backend="etcd", request_plane="tcp"):
"""Create a DistributedRuntime instance for testing. """Create a DistributedRuntime instance for testing.
Args: Args:
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd". store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
request_plane: How frontend talks to backend ("tcp", "http" or "nats"). Defaults to "nats". request_plane: How frontend talks to backend ("tcp", "http" or "nats"). Defaults to "tcp".
""" """
try: try:
# Try to get running loop (works in async context) # Try to get running loop (works in async context)
...@@ -619,6 +619,7 @@ def _test_router_basic( ...@@ -619,6 +619,7 @@ def _test_router_basic(
num_requests: int, num_requests: int,
frontend_timeout: int = 120, frontend_timeout: int = 120,
store_backend: str = "etcd", store_backend: str = "etcd",
request_plane: str = "nats",
): ):
"""Basic router test: start router, wait for workers and send concurrent requests via HTTP frontend. """Basic router test: start router, wait for workers and send concurrent requests via HTTP frontend.
...@@ -636,6 +637,7 @@ def _test_router_basic( ...@@ -636,6 +637,7 @@ def _test_router_basic(
num_requests: Number of concurrent requests to send num_requests: Number of concurrent requests to send
frontend_timeout: Timeout for frontend readiness check (default: 120s) frontend_timeout: Timeout for frontend readiness check (default: 120s)
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd". store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
request_plane: Request plane to use ("nats", "tcp", or "http"). Defaults to "nats".
Raises: Raises:
AssertionError: If requests fail or frontend doesn't become ready AssertionError: If requests fail or frontend doesn't become ready
...@@ -645,7 +647,12 @@ def _test_router_basic( ...@@ -645,7 +647,12 @@ def _test_router_basic(
# Start KV router frontend # Start KV router frontend
logger.info(f"Starting KV router frontend on port {frontend_port}") logger.info(f"Starting KV router frontend on port {frontend_port}")
kv_router = KVRouterProcess( kv_router = KVRouterProcess(
request, block_size, frontend_port, engine_workers.namespace, store_backend request,
block_size,
frontend_port,
engine_workers.namespace,
store_backend,
request_plane=request_plane,
) )
kv_router.__enter__() kv_router.__enter__()
...@@ -1668,6 +1675,7 @@ def _test_router_decisions_disagg( ...@@ -1668,6 +1675,7 @@ def _test_router_decisions_disagg(
frontend_port: int, frontend_port: int,
test_payload: dict, test_payload: dict,
store_backend: str = "etcd", store_backend: str = "etcd",
request_plane: str = "nats",
): ):
"""Validate KV cache prefix reuse in disaggregated prefill-decode setup via HTTP frontend. """Validate KV cache prefix reuse in disaggregated prefill-decode setup via HTTP frontend.
...@@ -1707,6 +1715,7 @@ def _test_router_decisions_disagg( ...@@ -1707,6 +1715,7 @@ def _test_router_decisions_disagg(
decode_workers.namespace, decode_workers.namespace,
store_backend, store_backend,
enforce_disagg=True, enforce_disagg=True,
request_plane=request_plane,
) )
kv_router.__enter__() kv_router.__enter__()
...@@ -1834,7 +1843,23 @@ def _test_router_decisions_disagg( ...@@ -1834,7 +1843,23 @@ def _test_router_decisions_disagg(
f"Make sure nvext.extra_fields=['worker_id'] is being processed." f"Make sure nvext.extra_fields=['worker_id'] is being processed."
) )
# Verify all prefill_worker_ids are the same (prefix reuse) # Verify prefix reuse behavior.
#
# In JetStream (KV events enabled) mode, the router learns cache state from KV events.
# With the TCP request plane, we can observe a transient on the *first* request where
# the second request is routed before the first request's KV "stored" events have been
# fully ingested. After ingestion, routing stabilizes.
#
# So for TCP we assert that requests 2-4 converge to the same prefill worker; for NATS
# request plane we keep the stronger assertion that all 4 match.
if request_plane == "tcp":
unique_prefill_ids = set(prefill_ids[1:])
assert len(unique_prefill_ids) == 1, (
f"Expected prefill requests 2-4 to route to the same worker due to prefix reuse, "
f"but found {len(unique_prefill_ids)} unique prefill_worker_ids: {unique_prefill_ids}. "
f"Full list: {prefill_ids}"
)
else:
unique_prefill_ids = set(prefill_ids) unique_prefill_ids = set(prefill_ids)
assert len(unique_prefill_ids) == 1, ( assert len(unique_prefill_ids) == 1, (
f"Expected all prefill requests to route to the same worker due to prefix reuse, " f"Expected all prefill requests to route to the same worker due to prefix reuse, "
...@@ -1901,11 +1926,30 @@ def _test_router_decisions( ...@@ -1901,11 +1926,30 @@ def _test_router_decisions(
# Use async to manage the test flow # Use async to manage the test flow
async def test_sync(): async def test_sync():
# Calculate expected number of instances
# With data parallelism:
# - vLLM/SGLang: each DP rank registers as a separate instance
# - Mockers: all DP ranks share the same worker instance ID (instance_ids returns worker IDs)
if test_dp_rank:
if (
hasattr(engine_workers, "data_parallel_size")
and engine_workers.data_parallel_size is not None
):
# vLLM/SGLang: each DP rank registers as a separate instance
expected_num_instances = (
engine_workers.num_workers * engine_workers.data_parallel_size
)
else:
# Mockers with dp_size or no DP: instance_ids() returns worker IDs
expected_num_instances = engine_workers.num_workers
else:
expected_num_instances = engine_workers.num_workers
# Wait for workers to be ready and get their instance IDs # Wait for workers to be ready and get their instance IDs
worker_ids = await wait_for_workers_ready( worker_ids = await wait_for_workers_ready(
endpoint, endpoint,
kv_push_router, kv_push_router,
expected_num_workers=engine_workers.num_workers, expected_num_workers=expected_num_instances,
model_name=model_name, model_name=model_name,
) )
logger.info(f"Workers ready: {worker_ids}") logger.info(f"Workers ready: {worker_ids}")
...@@ -1971,7 +2015,7 @@ def _test_router_decisions( ...@@ -1971,7 +2015,7 @@ def _test_router_decisions(
) )
# Wait a bit between requests # Wait a bit between requests
await asyncio.sleep(0.5) await asyncio.sleep(2)
# Wait for final synchronization (especially important for DP) # Wait for final synchronization (especially important for DP)
if test_dp_rank: if test_dp_rank:
......
...@@ -177,6 +177,8 @@ class MockerProcess: ...@@ -177,6 +177,8 @@ class MockerProcess:
self.num_workers = num_mockers self.num_workers = num_mockers
mocker_args = mocker_args or {} mocker_args = mocker_args or {}
# Store dp_size for DP-aware test functions
self.dp_size = mocker_args.get("dp_size")
command = _build_mocker_command( command = _build_mocker_command(
endpoint=self.endpoint, endpoint=self.endpoint,
...@@ -230,6 +232,7 @@ class DisaggMockerProcess: ...@@ -230,6 +232,7 @@ class DisaggMockerProcess:
mocker_args: Optional[Dict[str, Any]] = None, mocker_args: Optional[Dict[str, Any]] = None,
num_mockers: int = 1, num_mockers: int = 1,
store_backend: str = "etcd", store_backend: str = "etcd",
request_plane: str = "nats",
): ):
if worker_type not in ("prefill", "decode"): if worker_type not in ("prefill", "decode"):
raise ValueError( raise ValueError(
...@@ -258,8 +261,12 @@ class DisaggMockerProcess: ...@@ -258,8 +261,12 @@ class DisaggMockerProcess:
worker_type=worker_type, worker_type=worker_type,
) )
env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request_plane
self._process = ManagedProcess( self._process = ManagedProcess(
command=command, command=command,
env=env,
timeout=60, timeout=60,
display_output=True, display_output=True,
health_check_ports=[], health_check_ports=[],
...@@ -285,16 +292,18 @@ class DisaggMockerProcess: ...@@ -285,16 +292,18 @@ class DisaggMockerProcess:
@pytest.mark.timeout(42) # ~3x average (~13.80s), rounded up @pytest.mark.timeout(42) # ~3x average (~13.80s), rounded up
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
def test_mocker_kv_router( def test_mocker_kv_router(
request, runtime_services_dynamic_ports, predownload_tokenizers request, runtime_services_dynamic_ports, predownload_tokenizers, request_plane
): ):
""" """
Test KV router with multiple mocker engine instances. Test KV router with multiple mocker engine instances.
This test doesn't require GPUs and runs quickly for pre-merge validation. This test doesn't require GPUs and runs quickly for pre-merge validation.
Tests both NATS and TCP request planes.
""" """
# runtime_services starts etcd and nats # runtime_services starts etcd and optionally nats based on request_plane
logger.info("Starting mocker KV router test") logger.info(f"Starting mocker KV router test with request_plane={request_plane}")
# Create mocker args dictionary # Create mocker args dictionary
mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE} mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE}
...@@ -303,13 +312,18 @@ def test_mocker_kv_router( ...@@ -303,13 +312,18 @@ def test_mocker_kv_router(
# Start mocker instances with the new CLI interface # Start mocker instances with the new CLI interface
logger.info(f"Starting {NUM_MOCKERS} mocker instances") logger.info(f"Starting {NUM_MOCKERS} mocker instances")
mockers = MockerProcess( mockers = MockerProcess(
request, mocker_args=mocker_args, num_mockers=NUM_MOCKERS request,
mocker_args=mocker_args,
num_mockers=NUM_MOCKERS,
request_plane=request_plane,
) )
logger.info(f"All mockers using endpoint: {mockers.endpoint}") logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__() mockers.__enter__()
# Get unique port for this test # Get unique port for this test
frontend_port = get_unique_ports(request, num_ports=1)[0] frontend_port = get_unique_ports(
request, num_ports=1, request_plane=request_plane
)[0]
# Run basic router test (starts router internally and waits for workers to be ready) # Run basic router test (starts router internally and waits for workers to be ready)
_test_router_basic( _test_router_basic(
...@@ -319,6 +333,7 @@ def test_mocker_kv_router( ...@@ -319,6 +333,7 @@ def test_mocker_kv_router(
frontend_port=frontend_port, frontend_port=frontend_port,
test_payload=TEST_PAYLOAD, test_payload=TEST_PAYLOAD,
num_requests=NUM_REQUESTS, num_requests=NUM_REQUESTS,
request_plane=request_plane,
) )
finally: finally:
...@@ -422,8 +437,9 @@ def test_mocker_kv_router_overload_503( ...@@ -422,8 +437,9 @@ def test_mocker_kv_router_overload_503(
@pytest.mark.timeout(22) # ~3x average (~7.10s), rounded up @pytest.mark.timeout(22) # ~3x average (~7.10s), rounded up
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
def test_kv_push_router_bindings( def test_kv_push_router_bindings(
request, runtime_services_dynamic_ports, predownload_tokenizers request, runtime_services_dynamic_ports, predownload_tokenizers, request_plane
): ):
"""Test KvPushRouter Python bindings with mocker engines.""" """Test KvPushRouter Python bindings with mocker engines."""
logger.info("Starting KvPushRouter bindings test") logger.info("Starting KvPushRouter bindings test")
...@@ -433,13 +449,16 @@ def test_kv_push_router_bindings( ...@@ -433,13 +449,16 @@ def test_kv_push_router_bindings(
# Start mocker instances # Start mocker instances
logger.info(f"Starting {NUM_MOCKERS} mocker instances") logger.info(f"Starting {NUM_MOCKERS} mocker instances")
mockers = MockerProcess( mockers = MockerProcess(
request, mocker_args=mocker_args, num_mockers=NUM_MOCKERS request,
mocker_args=mocker_args,
num_mockers=NUM_MOCKERS,
request_plane=request_plane,
) )
logger.info(f"All mockers using endpoint: {mockers.endpoint}") logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__() mockers.__enter__()
# Get runtime and create endpoint # Get runtime and create endpoint
runtime = get_runtime() runtime = get_runtime(request_plane=request_plane)
namespace = runtime.namespace(mockers.namespace) namespace = runtime.namespace(mockers.namespace)
component = namespace.component(mockers.component_name) component = namespace.component(mockers.component_name)
endpoint = component.endpoint("generate") endpoint = component.endpoint("generate")
...@@ -571,17 +590,21 @@ def test_query_instance_id_returns_worker_and_tokens( ...@@ -571,17 +590,21 @@ def test_query_instance_id_returns_worker_and_tokens(
mockers.__exit__(None, None, None) mockers.__exit__(None, None, None)
@pytest.mark.parametrize("use_nats_core", [False, True], ids=["jetstream", "nats_core"])
@pytest.mark.timeout(29) # ~3x average (~9.55s), rounded up @pytest.mark.timeout(29) # ~3x average (~9.55s), rounded up
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.parametrize("use_nats_core", [False, True], ids=["jetstream", "nats_core"])
def test_router_decisions( def test_router_decisions(
request, runtime_services_dynamic_ports, predownload_tokenizers, use_nats_core request,
runtime_services_dynamic_ports,
predownload_tokenizers,
use_nats_core,
request_plane,
): ):
"""Validate KV cache prefix reuse and dp_rank routing by sending progressive requests with overlapping prefixes. """Validate KV cache prefix reuse and dp_rank routing by sending progressive requests with overlapping prefixes.
Parameterized to test both JetStream (default) and NATS Core (local indexer) modes. Parameterized to test both JetStream (default) and NATS Core (local indexer) modes.
""" """
# runtime_services_dynamic_ports handles NATS and etcd startup
# runtime_services starts etcd and nats
mode = "NATS Core (local indexer)" if use_nats_core else "JetStream" mode = "NATS Core (local indexer)" if use_nats_core else "JetStream"
logger.info( logger.info(
f"Starting test router prefix reuse and KV events synchronization ({mode})" f"Starting test router prefix reuse and KV events synchronization ({mode})"
...@@ -599,14 +622,19 @@ def test_router_decisions( ...@@ -599,14 +622,19 @@ def test_router_decisions(
logger.info( logger.info(
f"Starting 2 mocker instances with dp_size=4 each (8 total dp ranks), {mode}" f"Starting 2 mocker instances with dp_size=4 each (8 total dp ranks), {mode}"
) )
mockers = MockerProcess(request, mocker_args=mocker_args, num_mockers=2) mockers = MockerProcess(
request,
mocker_args=mocker_args,
num_mockers=2,
request_plane=request_plane,
)
logger.info(f"All mockers using endpoint: {mockers.endpoint}") logger.info(f"All mockers using endpoint: {mockers.endpoint}")
# Initialize mockers # Initialize mockers
mockers.__enter__() mockers.__enter__()
# Get runtime and create endpoint # Get runtime and create endpoint
runtime = get_runtime() runtime = get_runtime(request_plane=request_plane)
# Use the namespace from the mockers # Use the namespace from the mockers
namespace = runtime.namespace(mockers.namespace) namespace = runtime.namespace(mockers.namespace)
component = namespace.component("mocker") component = namespace.component("mocker")
...@@ -621,10 +649,15 @@ def test_router_decisions( ...@@ -621,10 +649,15 @@ def test_router_decisions(
mockers.__exit__(None, None, None) mockers.__exit__(None, None, None)
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.parametrize("registration_order", ["prefill_first", "decode_first"]) @pytest.mark.parametrize("registration_order", ["prefill_first", "decode_first"])
@pytest.mark.timeout(59) # ~3x average (~19.51s), rounded up @pytest.mark.timeout(59) # ~3x average (~19.51s), rounded up
def test_router_decisions_disagg( def test_router_decisions_disagg(
request, runtime_services_dynamic_ports, predownload_tokenizers, registration_order request,
runtime_services_dynamic_ports,
predownload_tokenizers,
registration_order,
request_plane,
): ):
"""Validate KV cache prefix reuse in disaggregated prefill-decode setup. """Validate KV cache prefix reuse in disaggregated prefill-decode setup.
...@@ -635,6 +668,7 @@ def test_router_decisions_disagg( ...@@ -635,6 +668,7 @@ def test_router_decisions_disagg(
- prefill_first: prefill workers register before decode workers - prefill_first: prefill workers register before decode workers
- decode_first: decode workers register before prefill workers - decode_first: decode workers register before prefill workers
""" """
# runtime_services_dynamic_ports handles NATS and etcd startup
logger.info( logger.info(
f"Starting disaggregated router prefix reuse test " f"Starting disaggregated router prefix reuse test "
f"(registration_order={registration_order})" f"(registration_order={registration_order})"
...@@ -660,6 +694,7 @@ def test_router_decisions_disagg( ...@@ -660,6 +694,7 @@ def test_router_decisions_disagg(
worker_type="prefill", worker_type="prefill",
mocker_args=mocker_args, mocker_args=mocker_args,
num_mockers=4, num_mockers=4,
request_plane=request_plane,
) )
prefill_workers.__enter__() prefill_workers.__enter__()
logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}") logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")
...@@ -672,6 +707,7 @@ def test_router_decisions_disagg( ...@@ -672,6 +707,7 @@ def test_router_decisions_disagg(
worker_type="decode", worker_type="decode",
mocker_args=mocker_args, mocker_args=mocker_args,
num_mockers=4, num_mockers=4,
request_plane=request_plane,
) )
decode_workers.__enter__() decode_workers.__enter__()
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}") logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
...@@ -684,6 +720,7 @@ def test_router_decisions_disagg( ...@@ -684,6 +720,7 @@ def test_router_decisions_disagg(
worker_type="decode", worker_type="decode",
mocker_args=mocker_args, mocker_args=mocker_args,
num_mockers=4, num_mockers=4,
request_plane=request_plane,
) )
decode_workers.__enter__() decode_workers.__enter__()
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}") logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
...@@ -696,6 +733,7 @@ def test_router_decisions_disagg( ...@@ -696,6 +733,7 @@ def test_router_decisions_disagg(
worker_type="prefill", worker_type="prefill",
mocker_args=mocker_args, mocker_args=mocker_args,
num_mockers=4, num_mockers=4,
request_plane=request_plane,
) )
prefill_workers.__enter__() prefill_workers.__enter__()
logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}") logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")
...@@ -713,6 +751,7 @@ def test_router_decisions_disagg( ...@@ -713,6 +751,7 @@ def test_router_decisions_disagg(
request=request, request=request,
frontend_port=frontend_port, frontend_port=frontend_port,
test_payload=TEST_PAYLOAD, test_payload=TEST_PAYLOAD,
request_plane=request_plane,
) )
finally: finally:
...@@ -736,6 +775,7 @@ def test_busy_threshold_endpoint( ...@@ -736,6 +775,7 @@ def test_busy_threshold_endpoint(
For now, this test only verifies the endpoint is accessible and returns valid responses. For now, this test only verifies the endpoint is accessible and returns valid responses.
""" """
# runtime_services_dynamic_ports handles NATS and etcd startup
logger.info( logger.info(
f"Starting busy_threshold endpoint test with request_plane={request_plane}" f"Starting busy_threshold endpoint test with request_plane={request_plane}"
) )
......
...@@ -85,6 +85,8 @@ class SGLangProcess: ...@@ -85,6 +85,8 @@ class SGLangProcess:
num_workers: int = 2, num_workers: int = 2,
single_gpu: bool = False, single_gpu: bool = False,
data_parallel_size: Optional[int] = None, data_parallel_size: Optional[int] = None,
request_plane: str = "tcp",
store_backend: str = "etcd",
): ):
"""Initialize SGLang workers with dynamo integration. """Initialize SGLang workers with dynamo integration.
...@@ -99,6 +101,8 @@ class SGLangProcess: ...@@ -99,6 +101,8 @@ class SGLangProcess:
num_workers: Number of SGLang worker processes num_workers: Number of SGLang worker processes
single_gpu: If True, all workers share GPU 0 single_gpu: If True, all workers share GPU 0
data_parallel_size: If set, enables data parallelism with this many ranks (num_workers must equal data_parallel_size) data_parallel_size: If set, enables data parallelism with this many ranks (num_workers must equal data_parallel_size)
request_plane: Request plane to use ("nats", "tcp", or "http"). Defaults to "tcp".
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
""" """
# Generate unique namespace for isolation # Generate unique namespace for isolation
namespace_suffix = generate_random_suffix() namespace_suffix = generate_random_suffix()
...@@ -106,7 +110,9 @@ class SGLangProcess: ...@@ -106,7 +110,9 @@ class SGLangProcess:
self.component_name = "backend" self.component_name = "backend"
self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate" self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate"
self.num_workers = num_workers self.num_workers = num_workers
self.data_parallel_size = data_parallel_size
self.worker_processes = [] self.worker_processes = []
self.store_backend = store_backend
if sglang_args is None: if sglang_args is None:
sglang_args = {} sglang_args = {}
...@@ -175,13 +181,18 @@ class SGLangProcess: ...@@ -175,13 +181,18 @@ class SGLangProcess:
command.extend(["--kv-events-config", kv_events_config]) command.extend(["--kv-events-config", kv_events_config])
env = os.environ.copy() # Copy parent environment env = os.environ.copy() # Copy parent environment
env.update( env_vars = {
{
"CUDA_VISIBLE_DEVICES": gpu_device, "CUDA_VISIBLE_DEVICES": gpu_device,
"DYN_NAMESPACE": self.namespace, "DYN_NAMESPACE": self.namespace,
"DYN_REQUEST_PLANE": request_plane,
"PYTHONHASHSEED": "0", # for deterministic event id's "PYTHONHASHSEED": "0", # for deterministic event id's
} }
)
# Add DYN_FILE_KV if using file storage backend
if self.store_backend == "file" and "DYN_FILE_KV" in os.environ:
env_vars["DYN_FILE_KV"] = os.environ["DYN_FILE_KV"]
env.update(env_vars)
# Create managed process for the worker # Create managed process for the worker
process = ManagedProcess( process = ManagedProcess(
...@@ -302,17 +313,25 @@ class SGLangProcess: ...@@ -302,17 +313,25 @@ class SGLangProcess:
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.timeout(150) # ~3x average (~46s/test), rounded up @pytest.mark.timeout(150) # ~3x average (~46s/test), rounded up
def test_sglang_kv_router_basic( def test_sglang_kv_router_basic(
request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm request,
runtime_services_dynamic_ports,
predownload_models,
set_ucx_tls_no_mm,
request_plane,
): ):
""" """
Quick e2e sanity test for KV router with SGLang engine instances. Quick e2e sanity test for KV router with SGLang engine instances.
Tests both NATS and TCP request planes.
""" """
# runtime_services starts etcd and nats # runtime_services starts etcd and nats
N_SGLANG_WORKERS = 2 N_SGLANG_WORKERS = 2
logger.info(f"Starting SGLang KV router test with {N_SGLANG_WORKERS} workers") logger.info(
f"Starting SGLang KV router test with {N_SGLANG_WORKERS} workers using request_plane={request_plane}"
)
try: try:
# Start SGLang workers # Start SGLang workers
...@@ -322,6 +341,7 @@ def test_sglang_kv_router_basic( ...@@ -322,6 +341,7 @@ def test_sglang_kv_router_basic(
sglang_args=SGLANG_ARGS, sglang_args=SGLANG_ARGS,
num_workers=N_SGLANG_WORKERS, num_workers=N_SGLANG_WORKERS,
single_gpu=True, # fit workers into one GPU single_gpu=True, # fit workers into one GPU
request_plane=request_plane,
) )
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}") logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
sglang_workers.__enter__() sglang_workers.__enter__()
...@@ -337,6 +357,7 @@ def test_sglang_kv_router_basic( ...@@ -337,6 +357,7 @@ def test_sglang_kv_router_basic(
num_requests=NUM_REQUESTS, num_requests=NUM_REQUESTS,
frontend_timeout=180, # 3 minutes should be plenty for TinyLlama frontend_timeout=180, # 3 minutes should be plenty for TinyLlama
store_backend="etcd", # Explicit for clarity store_backend="etcd", # Explicit for clarity
request_plane=request_plane,
) )
finally: finally:
...@@ -348,8 +369,13 @@ def test_sglang_kv_router_basic( ...@@ -348,8 +369,13 @@ def test_sglang_kv_router_basic(
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.skip(reason="Broken by sglang changes") @pytest.mark.skip(reason="Broken by sglang changes")
# TODO: Re-enable this test once https://github.com/sgl-project/sglang/pull/14934 is merged # TODO: Re-enable this test once https://github.com/sgl-project/sglang/pull/14934 is merged
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
def test_router_decisions_sglang_multiple_workers( def test_router_decisions_sglang_multiple_workers(
request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm request,
runtime_services_dynamic_ports,
predownload_models,
set_ucx_tls_no_mm,
request_plane,
): ):
# runtime_services starts etcd and nats # runtime_services starts etcd and nats
logger.info("Starting SGLang router prefix reuse test with two workers") logger.info("Starting SGLang router prefix reuse test with two workers")
...@@ -363,6 +389,7 @@ def test_router_decisions_sglang_multiple_workers( ...@@ -363,6 +389,7 @@ def test_router_decisions_sglang_multiple_workers(
sglang_args=SGLANG_ARGS, sglang_args=SGLANG_ARGS,
num_workers=N_WORKERS, num_workers=N_WORKERS,
single_gpu=True, # Worker uses GPU 0 single_gpu=True, # Worker uses GPU 0
request_plane=request_plane,
) )
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}") logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
...@@ -370,7 +397,7 @@ def test_router_decisions_sglang_multiple_workers( ...@@ -370,7 +397,7 @@ def test_router_decisions_sglang_multiple_workers(
sglang_workers.__enter__() sglang_workers.__enter__()
# Get runtime and create endpoint # Get runtime and create endpoint
runtime = get_runtime() runtime = get_runtime(request_plane=request_plane)
namespace = runtime.namespace(sglang_workers.namespace) namespace = runtime.namespace(sglang_workers.namespace)
component = namespace.component("backend") component = namespace.component("backend")
endpoint = component.endpoint("generate") endpoint = component.endpoint("generate")
...@@ -386,9 +413,14 @@ def test_router_decisions_sglang_multiple_workers( ...@@ -386,9 +413,14 @@ def test_router_decisions_sglang_multiple_workers(
@pytest.mark.gpu_2 @pytest.mark.gpu_2
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.timeout(600) # 10 min max (multi-GPU + DP startup variance) @pytest.mark.timeout(600) # 10 min max (multi-GPU + DP startup variance)
def test_router_decisions_sglang_dp( def test_router_decisions_sglang_dp(
request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm request,
runtime_services_dynamic_ports,
predownload_models,
set_ucx_tls_no_mm,
request_plane,
): ):
"""Validate KV cache prefix reuse with SGLang by sending progressive requests with overlapping prefixes. """Validate KV cache prefix reuse with SGLang by sending progressive requests with overlapping prefixes.
Same flow as test_router_decisions_sglang_multiple_workers; force first request to (worker_id, dp_rank=1). Same flow as test_router_decisions_sglang_multiple_workers; force first request to (worker_id, dp_rank=1).
...@@ -408,12 +440,13 @@ def test_router_decisions_sglang_dp( ...@@ -408,12 +440,13 @@ def test_router_decisions_sglang_dp(
num_workers=N_WORKERS, # Ignored when data_parallel_size is set num_workers=N_WORKERS, # Ignored when data_parallel_size is set
single_gpu=False, single_gpu=False,
data_parallel_size=DP_SIZE, # Creates DP_SIZE processes (one per rank) data_parallel_size=DP_SIZE, # Creates DP_SIZE processes (one per rank)
request_plane=request_plane,
) )
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}") logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
sglang_workers.__enter__() sglang_workers.__enter__()
# Get runtime and create endpoint # Get runtime and create endpoint
runtime = get_runtime() runtime = get_runtime(request_plane=request_plane)
# Use the namespace from the SGLang workers # Use the namespace from the SGLang workers
namespace = runtime.namespace(sglang_workers.namespace) namespace = runtime.namespace(sglang_workers.namespace)
component = namespace.component("backend") # endpoint is backend.generate component = namespace.component("backend") # endpoint is backend.generate
...@@ -431,15 +464,39 @@ def test_router_decisions_sglang_dp( ...@@ -431,15 +464,39 @@ def test_router_decisions_sglang_dp(
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.parametrize(
"store_backend,use_nats_core,request_plane",
[
("etcd", False, "nats"), # JetStream mode
# ("etcd", True, "tcp"), # ignored, needs unconditional nats_client
# ("file", False, "nats"), # File backend - TODO: investigate file backend support for SGLang
],
ids=["jetstream"], # "nats_core" and "file" commented out
)
@pytest.mark.timeout(150) # ~3x average (~46s/test), rounded up @pytest.mark.timeout(150) # ~3x average (~46s/test), rounded up
def test_sglang_indexers_sync( def test_sglang_indexers_sync(
request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm request,
runtime_services_dynamic_ports,
predownload_models,
file_storage_backend,
set_ucx_tls_no_mm,
store_backend,
use_nats_core,
request_plane,
): ):
""" """
Test that two KV routers have synchronized indexer states after processing requests Test that two KV routers have synchronized indexer states after processing requests
with SGLang workers. This test verifies that both routers converge to the same internal state. with SGLang workers. This test verifies that both routers converge to the same internal state.
Tests with configuration:
- jetstream: etcd backend, JetStream for KV events, NATS request plane
""" """
logger.info("Starting SGLang indexers sync test") # runtime_services_dynamic_ports handles NATS and etcd startup
logger.info(
f"Starting SGLang indexers sync test: store_backend={store_backend}, "
f"use_nats_core={use_nats_core}, request_plane={request_plane}"
)
N_SGLANG_WORKERS = 2 N_SGLANG_WORKERS = 2
try: try:
...@@ -450,6 +507,8 @@ def test_sglang_indexers_sync( ...@@ -450,6 +507,8 @@ def test_sglang_indexers_sync(
sglang_args=SGLANG_ARGS, sglang_args=SGLANG_ARGS,
num_workers=N_SGLANG_WORKERS, num_workers=N_SGLANG_WORKERS,
single_gpu=True, # fit workers into one GPU single_gpu=True, # fit workers into one GPU
request_plane=request_plane,
store_backend=store_backend,
) )
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}") logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
sglang_workers.__enter__() sglang_workers.__enter__()
...@@ -461,7 +520,8 @@ def test_sglang_indexers_sync( ...@@ -461,7 +520,8 @@ def test_sglang_indexers_sync(
block_size=PAGE_SIZE, block_size=PAGE_SIZE,
model_name=MODEL_NAME, model_name=MODEL_NAME,
num_workers=N_SGLANG_WORKERS, num_workers=N_SGLANG_WORKERS,
store_backend="etcd", store_backend=store_backend,
request_plane=request_plane,
) )
logger.info("SGLang indexers sync test completed successfully") logger.info("SGLang indexers sync test completed successfully")
......
...@@ -82,6 +82,8 @@ class TRTLLMProcess: ...@@ -82,6 +82,8 @@ class TRTLLMProcess:
trtllm_args: Optional[Dict[str, Any]] = None, trtllm_args: Optional[Dict[str, Any]] = None,
num_workers: int = 2, num_workers: int = 2,
single_gpu: bool = False, single_gpu: bool = False,
request_plane: str = "tcp",
store_backend: str = "etcd",
): ):
"""Initialize TRT-LLM workers with dynamo integration. """Initialize TRT-LLM workers with dynamo integration.
...@@ -94,6 +96,8 @@ class TRTLLMProcess: ...@@ -94,6 +96,8 @@ class TRTLLMProcess:
- max_seq_len: Maximum sequence length (optional) - max_seq_len: Maximum sequence length (optional)
num_workers: Number of TRT-LLM worker processes num_workers: Number of TRT-LLM worker processes
single_gpu: If True, all workers share GPU 0 single_gpu: If True, all workers share GPU 0
request_plane: Request plane to use ("nats", "tcp", or "http"). Defaults to "tcp".
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
Note: TRT-LLM doesn't support data parallelism like vLLM (dp_rank is always 0). Note: TRT-LLM doesn't support data parallelism like vLLM (dp_rank is always 0).
Tensor parallelism (TP) is supported but creates 1 worker spanning multiple GPUs, Tensor parallelism (TP) is supported but creates 1 worker spanning multiple GPUs,
...@@ -106,6 +110,7 @@ class TRTLLMProcess: ...@@ -106,6 +110,7 @@ class TRTLLMProcess:
self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate" self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate"
self.num_workers = num_workers self.num_workers = num_workers
self.worker_processes = [] self.worker_processes = []
self.store_backend = store_backend
if trtllm_args is None: if trtllm_args is None:
trtllm_args = {} trtllm_args = {}
...@@ -154,15 +159,20 @@ class TRTLLMProcess: ...@@ -154,15 +159,20 @@ class TRTLLMProcess:
system_port = 8081 + worker_idx system_port = 8081 + worker_idx
env = os.environ.copy() # Copy parent environment env = os.environ.copy() # Copy parent environment
env.update( env_vars = {
{
"CUDA_VISIBLE_DEVICES": gpu_device, "CUDA_VISIBLE_DEVICES": gpu_device,
"DYN_NAMESPACE": self.namespace, "DYN_NAMESPACE": self.namespace,
"DYN_REQUEST_PLANE": request_plane,
"PYTHONHASHSEED": "0", # for deterministic event id's "PYTHONHASHSEED": "0", # for deterministic event id's
# Set unique system port for each worker to avoid port conflicts # Set unique system port for each worker to avoid port conflicts
"DYN_SYSTEM_PORT": str(system_port), "DYN_SYSTEM_PORT": str(system_port),
} }
)
# Add DYN_FILE_KV if using file storage backend
if self.store_backend == "file" and "DYN_FILE_KV" in os.environ:
env_vars["DYN_FILE_KV"] = os.environ["DYN_FILE_KV"]
env.update(env_vars)
# Create managed process for the worker # Create managed process for the worker
process = ManagedProcess( process = ManagedProcess(
...@@ -276,17 +286,25 @@ class TRTLLMProcess: ...@@ -276,17 +286,25 @@ class TRTLLMProcess:
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.timeout(150) # ~3x average (~45s/test), rounded up @pytest.mark.timeout(150) # ~3x average (~45s/test), rounded up
def test_trtllm_kv_router_basic( def test_trtllm_kv_router_basic(
request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm request,
runtime_services_dynamic_ports,
predownload_models,
set_ucx_tls_no_mm,
request_plane,
): ):
""" """
Quick e2e sanity test for KV router with TRT-LLM engine instances. Quick e2e sanity test for KV router with TRT-LLM engine instances.
Tests both NATS and TCP request planes.
""" """
# runtime_services starts etcd and nats # runtime_services starts etcd and nats
N_TRTLLM_WORKERS = 2 N_TRTLLM_WORKERS = 2
logger.info(f"Starting TRT-LLM KV router test with {N_TRTLLM_WORKERS} workers") logger.info(
f"Starting TRT-LLM KV router test with {N_TRTLLM_WORKERS} workers using request_plane={request_plane}"
)
try: try:
# Start TRT-LLM workers # Start TRT-LLM workers
...@@ -296,6 +314,7 @@ def test_trtllm_kv_router_basic( ...@@ -296,6 +314,7 @@ def test_trtllm_kv_router_basic(
trtllm_args=TRTLLM_ARGS, trtllm_args=TRTLLM_ARGS,
num_workers=N_TRTLLM_WORKERS, num_workers=N_TRTLLM_WORKERS,
single_gpu=True, # fit workers into one GPU single_gpu=True, # fit workers into one GPU
request_plane=request_plane,
) )
logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}") logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}")
trtllm_workers.__enter__() trtllm_workers.__enter__()
...@@ -311,6 +330,7 @@ def test_trtllm_kv_router_basic( ...@@ -311,6 +330,7 @@ def test_trtllm_kv_router_basic(
num_requests=NUM_REQUESTS, num_requests=NUM_REQUESTS,
frontend_timeout=180, # 3 minutes should be plenty for TinyLlama frontend_timeout=180, # 3 minutes should be plenty for TinyLlama
store_backend="etcd", # Explicit for clarity store_backend="etcd", # Explicit for clarity
request_plane=request_plane,
) )
finally: finally:
...@@ -320,9 +340,14 @@ def test_trtllm_kv_router_basic( ...@@ -320,9 +340,14 @@ def test_trtllm_kv_router_basic(
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.timeout(150) # ~3x average (~45s/test), rounded up @pytest.mark.timeout(150) # ~3x average (~45s/test), rounded up
def test_router_decisions_trtllm_multiple_workers( def test_router_decisions_trtllm_multiple_workers(
request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm request,
runtime_services_dynamic_ports,
predownload_models,
set_ucx_tls_no_mm,
request_plane,
): ):
# runtime_services starts etcd and nats # runtime_services starts etcd and nats
logger.info("Starting TRT-LLM router prefix reuse test with two workers") logger.info("Starting TRT-LLM router prefix reuse test with two workers")
...@@ -338,6 +363,7 @@ def test_router_decisions_trtllm_multiple_workers( ...@@ -338,6 +363,7 @@ def test_router_decisions_trtllm_multiple_workers(
trtllm_args=TRTLLM_ARGS, trtllm_args=TRTLLM_ARGS,
num_workers=N_WORKERS, num_workers=N_WORKERS,
single_gpu=True, # Worker uses GPU 0 single_gpu=True, # Worker uses GPU 0
request_plane=request_plane,
) )
logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}") logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}")
...@@ -345,7 +371,7 @@ def test_router_decisions_trtllm_multiple_workers( ...@@ -345,7 +371,7 @@ def test_router_decisions_trtllm_multiple_workers(
trtllm_workers.__enter__() trtllm_workers.__enter__()
# Get runtime and create endpoint # Get runtime and create endpoint
runtime = get_runtime() runtime = get_runtime(request_plane=request_plane)
namespace = runtime.namespace(trtllm_workers.namespace) namespace = runtime.namespace(trtllm_workers.namespace)
component = namespace.component("tensorrt_llm") component = namespace.component("tensorrt_llm")
endpoint = component.endpoint("generate") endpoint = component.endpoint("generate")
...@@ -368,14 +394,38 @@ def test_router_decisions_trtllm_multiple_workers( ...@@ -368,14 +394,38 @@ def test_router_decisions_trtllm_multiple_workers(
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.timeout(150) # ~3x average (~45s/test), rounded up @pytest.mark.timeout(150) # ~3x average (~45s/test), rounded up
@pytest.mark.parametrize(
"store_backend,use_nats_core,request_plane",
[
("etcd", False, "nats"), # JetStream mode
# ("etcd", True, "tcp"), # ignored, needs unconditional nats_client
# ("file", False, "nats"), # File backend - TODO: investigate file backend support for TRT-LLM
],
ids=["jetstream"], # "nats_core" and "file" commented out
)
def test_trtllm_indexers_sync( def test_trtllm_indexers_sync(
request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm request,
runtime_services_dynamic_ports,
predownload_models,
file_storage_backend,
set_ucx_tls_no_mm,
store_backend,
use_nats_core,
request_plane,
): ):
""" """
Test that two KV routers have synchronized indexer states after processing requests Test that two KV routers have synchronized indexer states after processing requests
with TRT-LLM workers. This test verifies that both routers converge to the same internal state. with TRT-LLM workers. This test verifies that both routers converge to the same internal state.
Tests with configuration:
- jetstream: etcd backend, JetStream for KV events, NATS request plane
""" """
logger.info("Starting TRT-LLM indexers sync test") # runtime_services_dynamic_ports handles NATS and etcd startup
logger.info(
f"Starting TRT-LLM indexers sync test: store_backend={store_backend}, "
f"use_nats_core={use_nats_core}, request_plane={request_plane}"
)
N_TRTLLM_WORKERS = 2 N_TRTLLM_WORKERS = 2
try: try:
...@@ -386,6 +436,8 @@ def test_trtllm_indexers_sync( ...@@ -386,6 +436,8 @@ def test_trtllm_indexers_sync(
trtllm_args=TRTLLM_ARGS, trtllm_args=TRTLLM_ARGS,
num_workers=N_TRTLLM_WORKERS, num_workers=N_TRTLLM_WORKERS,
single_gpu=True, # fit workers into one GPU single_gpu=True, # fit workers into one GPU
request_plane=request_plane,
store_backend=store_backend,
) )
logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}") logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}")
trtllm_workers.__enter__() trtllm_workers.__enter__()
...@@ -397,7 +449,8 @@ def test_trtllm_indexers_sync( ...@@ -397,7 +449,8 @@ def test_trtllm_indexers_sync(
block_size=TRTLLM_BLOCK_SIZE, block_size=TRTLLM_BLOCK_SIZE,
model_name=MODEL_NAME, model_name=MODEL_NAME,
num_workers=N_TRTLLM_WORKERS, num_workers=N_TRTLLM_WORKERS,
store_backend="etcd", store_backend=store_backend,
request_plane=request_plane,
) )
logger.info("TRT-LLM indexers sync test completed successfully") logger.info("TRT-LLM indexers sync test completed successfully")
......
...@@ -85,6 +85,8 @@ class VLLMProcess: ...@@ -85,6 +85,8 @@ class VLLMProcess:
num_workers: int = 2, num_workers: int = 2,
single_gpu: bool = False, single_gpu: bool = False,
data_parallel_size: Optional[int] = None, data_parallel_size: Optional[int] = None,
request_plane: str = "tcp",
store_backend: str = "etcd",
): ):
"""Initialize vLLM workers with dynamo integration. """Initialize vLLM workers with dynamo integration.
...@@ -100,6 +102,8 @@ class VLLMProcess: ...@@ -100,6 +102,8 @@ class VLLMProcess:
num_workers: Number of vLLM worker processes num_workers: Number of vLLM worker processes
single_gpu: If True, all workers share GPU 0 single_gpu: If True, all workers share GPU 0
data_parallel_size: If set, enables data parallelism with this many ranks (num_workers must equal data_parallel_size) data_parallel_size: If set, enables data parallelism with this many ranks (num_workers must equal data_parallel_size)
request_plane: Request plane to use ("nats", "tcp", or "http"). Defaults to "tcp".
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
""" """
# Generate unique namespace for isolation # Generate unique namespace for isolation
namespace_suffix = generate_random_suffix() namespace_suffix = generate_random_suffix()
...@@ -107,7 +111,9 @@ class VLLMProcess: ...@@ -107,7 +111,9 @@ class VLLMProcess:
self.component_name = "backend" self.component_name = "backend"
self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate" self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate"
self.num_workers = num_workers self.num_workers = num_workers
self.data_parallel_size = data_parallel_size
self.worker_processes = [] self.worker_processes = []
self.store_backend = store_backend
if vllm_args is None: if vllm_args is None:
vllm_args = {} vllm_args = {}
...@@ -190,15 +196,20 @@ class VLLMProcess: ...@@ -190,15 +196,20 @@ class VLLMProcess:
) )
env = os.environ.copy() # Copy parent environment env = os.environ.copy() # Copy parent environment
env.update( env_vars = {
{
"CUDA_VISIBLE_DEVICES": gpu_device, "CUDA_VISIBLE_DEVICES": gpu_device,
"DYN_NAMESPACE": self.namespace, "DYN_NAMESPACE": self.namespace,
"DYN_REQUEST_PLANE": request_plane,
"DYN_VLLM_KV_EVENT_PORT": str(20080 + worker_idx), "DYN_VLLM_KV_EVENT_PORT": str(20080 + worker_idx),
"VLLM_NIXL_SIDE_CHANNEL_PORT": str(20090 + worker_idx), "VLLM_NIXL_SIDE_CHANNEL_PORT": str(20090 + worker_idx),
"PYTHONHASHSEED": "0", # for deterministic event id's "PYTHONHASHSEED": "0", # for deterministic event id's
} }
)
# Add DYN_FILE_KV if using file storage backend
if self.store_backend == "file" and "DYN_FILE_KV" in os.environ:
env_vars["DYN_FILE_KV"] = os.environ["DYN_FILE_KV"]
env.update(env_vars)
# Create managed process for the worker # Create managed process for the worker
process = ManagedProcess( process = ManagedProcess(
...@@ -318,16 +329,24 @@ class VLLMProcess: ...@@ -318,16 +329,24 @@ class VLLMProcess:
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.timeout(150) # ~3x average (~43s/test), rounded up @pytest.mark.timeout(150) # ~3x average (~43s/test), rounded up
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
def test_vllm_kv_router_basic( def test_vllm_kv_router_basic(
request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm request,
runtime_services_dynamic_ports,
predownload_models,
set_ucx_tls_no_mm,
request_plane,
): ):
""" """
Quick e2e sanity test for KV router with vLLM engine instances. Quick e2e sanity test for KV router with vLLM engine instances.
Tests both NATS and TCP request planes.
""" """
# runtime_services starts etcd and nats # runtime_services starts etcd and nats
N_VLLM_WORKERS = 2 N_VLLM_WORKERS = 2
logger.info(f"Starting vLLM KV router test with {N_VLLM_WORKERS} workers") logger.info(
f"Starting vLLM KV router test with {N_VLLM_WORKERS} workers using request_plane={request_plane}"
)
try: try:
# Start vLLM workers # Start vLLM workers
...@@ -337,6 +356,7 @@ def test_vllm_kv_router_basic( ...@@ -337,6 +356,7 @@ def test_vllm_kv_router_basic(
vllm_args=VLLM_ARGS, vllm_args=VLLM_ARGS,
num_workers=N_VLLM_WORKERS, num_workers=N_VLLM_WORKERS,
single_gpu=True, # fit workers into one GPU single_gpu=True, # fit workers into one GPU
request_plane=request_plane,
) )
logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}") logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}")
vllm_workers.__enter__() vllm_workers.__enter__()
...@@ -352,6 +372,7 @@ def test_vllm_kv_router_basic( ...@@ -352,6 +372,7 @@ def test_vllm_kv_router_basic(
num_requests=NUM_REQUESTS, num_requests=NUM_REQUESTS,
frontend_timeout=180, # 3 minutes should be plenty for TinyLlama frontend_timeout=180, # 3 minutes should be plenty for TinyLlama
store_backend="etcd", # Explicit for clarity store_backend="etcd", # Explicit for clarity
request_plane=request_plane,
) )
finally: finally:
...@@ -362,8 +383,13 @@ def test_vllm_kv_router_basic( ...@@ -362,8 +383,13 @@ def test_vllm_kv_router_basic(
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.timeout(150) # ~3x average (~43s/test), rounded up @pytest.mark.timeout(150) # ~3x average (~43s/test), rounded up
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
def test_router_decisions_vllm_multiple_workers( def test_router_decisions_vllm_multiple_workers(
request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm request,
runtime_services_dynamic_ports,
predownload_models,
set_ucx_tls_no_mm,
request_plane,
): ):
# runtime_services starts etcd and nats # runtime_services starts etcd and nats
logger.info("Starting vLLM router prefix reuse test with two workers") logger.info("Starting vLLM router prefix reuse test with two workers")
...@@ -377,6 +403,7 @@ def test_router_decisions_vllm_multiple_workers( ...@@ -377,6 +403,7 @@ def test_router_decisions_vllm_multiple_workers(
vllm_args=VLLM_ARGS, vllm_args=VLLM_ARGS,
num_workers=N_WORKERS, num_workers=N_WORKERS,
single_gpu=True, # Worker uses GPU 0 single_gpu=True, # Worker uses GPU 0
request_plane=request_plane,
) )
logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}") logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}")
...@@ -384,7 +411,7 @@ def test_router_decisions_vllm_multiple_workers( ...@@ -384,7 +411,7 @@ def test_router_decisions_vllm_multiple_workers(
vllm_workers.__enter__() vllm_workers.__enter__()
# Get runtime and create endpoint # Get runtime and create endpoint
runtime = get_runtime() runtime = get_runtime(request_plane=request_plane)
namespace = runtime.namespace(vllm_workers.namespace) namespace = runtime.namespace(vllm_workers.namespace)
component = namespace.component("backend") component = namespace.component("backend")
endpoint = component.endpoint("generate") endpoint = component.endpoint("generate")
...@@ -400,9 +427,14 @@ def test_router_decisions_vllm_multiple_workers( ...@@ -400,9 +427,14 @@ def test_router_decisions_vllm_multiple_workers(
@pytest.mark.gpu_2 @pytest.mark.gpu_2
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.timeout(600) # 10 min max (multi-GPU + DP startup variance) @pytest.mark.timeout(600) # 10 min max (multi-GPU + DP startup variance)
def test_router_decisions_vllm_dp( def test_router_decisions_vllm_dp(
request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm request,
runtime_services_dynamic_ports,
predownload_models,
set_ucx_tls_no_mm,
request_plane,
): ):
"""Validate KV cache prefix reuse with vLLM by sending progressive requests with overlapping prefixes. """Validate KV cache prefix reuse with vLLM by sending progressive requests with overlapping prefixes.
Same flow as test_router_decisions_vllm_multiple_workers; force first request to (worker_id, dp_rank=1). Same flow as test_router_decisions_vllm_multiple_workers; force first request to (worker_id, dp_rank=1).
...@@ -422,12 +454,13 @@ def test_router_decisions_vllm_dp( ...@@ -422,12 +454,13 @@ def test_router_decisions_vllm_dp(
num_workers=N_WORKERS, # Ignored when data_parallel_size is set num_workers=N_WORKERS, # Ignored when data_parallel_size is set
single_gpu=False, single_gpu=False,
data_parallel_size=DP_SIZE, # Creates DP_SIZE processes (one per rank) data_parallel_size=DP_SIZE, # Creates DP_SIZE processes (one per rank)
request_plane=request_plane,
) )
logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}") logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}")
vllm_workers.__enter__() vllm_workers.__enter__()
# Get runtime and create endpoint # Get runtime and create endpoint
runtime = get_runtime() runtime = get_runtime(request_plane=request_plane)
# Use the namespace from the vLLM workers # Use the namespace from the vLLM workers
namespace = runtime.namespace(vllm_workers.namespace) namespace = runtime.namespace(vllm_workers.namespace)
component = namespace.component("backend") # endpoint is backend.generate component = namespace.component("backend") # endpoint is backend.generate
...@@ -446,14 +479,39 @@ def test_router_decisions_vllm_dp( ...@@ -446,14 +479,39 @@ def test_router_decisions_vllm_dp(
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.timeout(150) # ~3x average (~43s/test), rounded up @pytest.mark.timeout(150) # ~3x average (~43s/test), rounded up
@pytest.mark.parametrize(
"store_backend,use_nats_core,request_plane",
[
("etcd", False, "nats"), # JetStream mode
("etcd", True, "tcp"), # nats_core mode
# ("file", False, "nats"), # File backend
],
ids=["jetstream", "tcp_nats_core"],
)
def test_vllm_indexers_sync( def test_vllm_indexers_sync(
request, runtime_services_dynamic_ports, predownload_models, set_ucx_tls_no_mm request,
runtime_services_dynamic_ports,
predownload_models,
file_storage_backend,
set_ucx_tls_no_mm,
store_backend,
use_nats_core,
request_plane,
): ):
""" """
Test that two KV routers have synchronized indexer states after processing requests Test that two KV routers have synchronized indexer states after processing requests
with vLLM workers. This test verifies that both routers converge to the same internal state. with vLLM workers. This test verifies that both routers converge to the same internal state.
Tests with configuration:
- jetstream: etcd backend, JetStream for KV events, NATS request plane
- tcp_nats_core: etcd backend, local indexer with NATS Core, TCP request plane
""" """
logger.info("Starting vLLM indexers sync test") # runtime_services_dynamic_ports handles NATS and etcd startup
logger.info(
f"Starting vLLM indexers sync test: store_backend={store_backend}, "
f"use_nats_core={use_nats_core}, request_plane={request_plane}"
)
N_VLLM_WORKERS = 2 N_VLLM_WORKERS = 2
try: try:
...@@ -464,6 +522,8 @@ def test_vllm_indexers_sync( ...@@ -464,6 +522,8 @@ def test_vllm_indexers_sync(
vllm_args=VLLM_ARGS, vllm_args=VLLM_ARGS,
num_workers=N_VLLM_WORKERS, num_workers=N_VLLM_WORKERS,
single_gpu=True, # fit workers into one GPU single_gpu=True, # fit workers into one GPU
request_plane=request_plane,
store_backend=store_backend,
) )
logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}") logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}")
vllm_workers.__enter__() vllm_workers.__enter__()
...@@ -475,7 +535,8 @@ def test_vllm_indexers_sync( ...@@ -475,7 +535,8 @@ def test_vllm_indexers_sync(
block_size=BLOCK_SIZE, block_size=BLOCK_SIZE,
model_name=MODEL_NAME, model_name=MODEL_NAME,
num_workers=N_VLLM_WORKERS, num_workers=N_VLLM_WORKERS,
store_backend="etcd", store_backend=store_backend,
request_plane=request_plane,
) )
logger.info("vLLM indexers sync test completed successfully") logger.info("vLLM indexers sync test completed successfully")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment