Unverified Commit 0d3ff440 authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

fix: enable toggling kv events pub/sub (currently nats based) with --no-kv-events flag (#5237)

parent bb8eaa23
...@@ -146,11 +146,13 @@ def parse_args(): ...@@ -146,11 +146,13 @@ def parse_args():
help="KV Router: Temperature for worker sampling via softmax. Higher values promote more randomness, and 0 fallbacks to deterministic.", help="KV Router: Temperature for worker sampling via softmax. Higher values promote more randomness, and 0 fallbacks to deterministic.",
) )
parser.add_argument( parser.add_argument(
"--no-kv-events", "--kv-events",
action="store_false", action=argparse.BooleanOptionalAction,
dest="use_kv_events", dest="use_kv_events",
default=os.environ.get("DYN_KV_EVENTS", "true").lower() != "false", default=(
help="KV Router: Disable KV events. When set, the router predicts cache state based on routing decisions with TTL-based expiration and pruning, rather than receiving events from workers. By default, KV events are enabled.", os.environ.get("DYN_KV_EVENTS", "true").lower() == "true"
), # default is true
help="KV Router: Enable/disable KV events. Use --kv-events to enable (default, router receives cache state events from workers) or --no-kv-events to disable (router predicts cache state based on routing decisions).",
) )
parser.add_argument( parser.add_argument(
"--router-ttl", "--router-ttl",
...@@ -325,8 +327,11 @@ async def async_main(): ...@@ -325,8 +327,11 @@ async def async_main():
if prefix: if prefix:
os.environ["DYN_METRICS_PREFIX"] = flags.metrics_prefix os.environ["DYN_METRICS_PREFIX"] = flags.metrics_prefix
# Enable NATS for KV router mode when kv_events are used (when --no-kv-events is not set)
enable_nats = (flags.router_mode == "kv") and flags.use_kv_events
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, flags.store_kv, flags.request_plane) runtime = DistributedRuntime(loop, flags.store_kv, flags.request_plane, enable_nats)
def signal_handler(): def signal_handler():
asyncio.create_task(graceful_shutdown(runtime)) asyncio.create_task(graceful_shutdown(runtime))
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import argparse import argparse
import contextlib import contextlib
import json
import logging import logging
import os import os
import socket import socket
...@@ -157,6 +158,8 @@ class DynamoArgs: ...@@ -157,6 +158,8 @@ class DynamoArgs:
dump_config_to: Optional[str] = None dump_config_to: Optional[str] = None
# local indexer option # local indexer option
enable_local_indexer: bool = False enable_local_indexer: bool = False
# Whether to enable NATS for KV events (derived from server_args.kv_events_config)
use_kv_events: bool = False
class DisaggregationMode(Enum): class DisaggregationMode(Enum):
...@@ -469,27 +472,6 @@ async def parse_args(args: list[str]) -> Config: ...@@ -469,27 +472,6 @@ async def parse_args(args: list[str]) -> Config:
f"Custom Jinja template file not found: {expanded_template_path}" f"Custom Jinja template file not found: {expanded_template_path}"
) )
dynamo_args = DynamoArgs(
namespace=parsed_namespace,
component=parsed_component_name,
endpoint=parsed_endpoint_name,
migration_limit=parsed_args.migration_limit,
store_kv=parsed_args.store_kv,
request_plane=parsed_args.request_plane,
tool_call_parser=tool_call_parser,
reasoning_parser=reasoning_parser,
custom_jinja_template=expanded_template_path,
dyn_endpoint_types=parsed_args.dyn_endpoint_types,
use_sglang_tokenizer=parsed_args.use_sglang_tokenizer,
multimodal_processor=parsed_args.multimodal_processor,
multimodal_encode_worker=parsed_args.multimodal_encode_worker,
multimodal_worker=parsed_args.multimodal_worker,
embedding_worker=parsed_args.embedding_worker,
dump_config_to=parsed_args.dump_config_to,
enable_local_indexer=str(parsed_args.enable_local_indexer).lower() == "true",
)
logging.debug(f"Dynamo args: {dynamo_args}")
model_path = parsed_args.model_path model_path = parsed_args.model_path
# Name the model # Name the model
if not parsed_args.served_model_name: if not parsed_args.served_model_name:
...@@ -520,6 +502,43 @@ async def parse_args(args: list[str]) -> Config: ...@@ -520,6 +502,43 @@ async def parse_args(args: list[str]) -> Config:
) )
server_args.skip_tokenizer_init = True server_args.skip_tokenizer_init = True
# Derive use_kv_events from server_args.kv_events_config
# Check that kv_events_config exists AND publisher is not "null" ("zmq" or any future publishers)
use_kv_events = False
if server_args.kv_events_config:
try:
kv_cfg = json.loads(server_args.kv_events_config)
use_kv_events = kv_cfg.get("publisher", "null") != "null"
except json.JSONDecodeError:
logging.warning(
f"Failed to parse kv_events_config: {server_args.kv_events_config}"
)
logging.info(
f"Derived use_kv_events={use_kv_events} from kv_events_config={server_args.kv_events_config}"
)
dynamo_args = DynamoArgs(
namespace=parsed_namespace,
component=parsed_component_name,
endpoint=parsed_endpoint_name,
migration_limit=parsed_args.migration_limit,
store_kv=parsed_args.store_kv,
request_plane=parsed_args.request_plane,
tool_call_parser=tool_call_parser,
reasoning_parser=reasoning_parser,
custom_jinja_template=expanded_template_path,
dyn_endpoint_types=parsed_args.dyn_endpoint_types,
use_sglang_tokenizer=parsed_args.use_sglang_tokenizer,
multimodal_processor=parsed_args.multimodal_processor,
multimodal_encode_worker=parsed_args.multimodal_encode_worker,
multimodal_worker=parsed_args.multimodal_worker,
embedding_worker=parsed_args.embedding_worker,
dump_config_to=parsed_args.dump_config_to,
enable_local_indexer=str(parsed_args.enable_local_indexer).lower() == "true",
use_kv_events=use_kv_events,
)
logging.debug(f"Dynamo args: {dynamo_args}")
return Config(server_args, dynamo_args) return Config(server_args, dynamo_args)
......
...@@ -70,8 +70,12 @@ async def worker(): ...@@ -70,8 +70,12 @@ async def worker():
dump_config(config.dynamo_args.dump_config_to, config) dump_config(config.dynamo_args.dump_config_to, config)
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
# Enable NATS based on use_kv_events flag (derived from kv_events_config)
runtime = DistributedRuntime( runtime = DistributedRuntime(
loop, config.dynamo_args.store_kv, config.dynamo_args.request_plane loop,
config.dynamo_args.store_kv,
config.dynamo_args.request_plane,
config.dynamo_args.use_kv_events,
) )
def signal_handler(): def signal_handler():
......
...@@ -113,7 +113,10 @@ async def worker(): ...@@ -113,7 +113,10 @@ async def worker():
config = cmd_line_args() config = cmd_line_args()
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, config.store_kv, config.request_plane) # Enable NATS based on use_kv_events flag (derived from publish_events_and_metrics)
runtime = DistributedRuntime(
loop, config.store_kv, config.request_plane, config.use_kv_events
)
# Set up signal handler for graceful shutdown # Set up signal handler for graceful shutdown
def signal_handler(): def signal_handler():
......
...@@ -62,6 +62,8 @@ class Config: ...@@ -62,6 +62,8 @@ class Config:
self.store_kv: str = "" self.store_kv: str = ""
self.request_plane: str = "" self.request_plane: str = ""
self.enable_local_indexer: bool = False self.enable_local_indexer: bool = False
# Whether to enable NATS for KV events (derived from publish_events_and_metrics)
self.use_kv_events: bool = False
def __str__(self) -> str: def __str__(self) -> str:
return ( return (
...@@ -95,7 +97,8 @@ class Config: ...@@ -95,7 +97,8 @@ class Config:
f"custom_jinja_template={self.custom_jinja_template}, " f"custom_jinja_template={self.custom_jinja_template}, "
f"store_kv={self.store_kv}, " f"store_kv={self.store_kv}, "
f"request_plane={self.request_plane}, " f"request_plane={self.request_plane}, "
f"enable_local_indexer={self.enable_local_indexer}" f"enable_local_indexer={self.enable_local_indexer}, "
f"use_kv_events={self.use_kv_events}"
) )
...@@ -375,6 +378,8 @@ def cmd_line_args(): ...@@ -375,6 +378,8 @@ def cmd_line_args():
config.store_kv = args.store_kv config.store_kv = args.store_kv
config.request_plane = args.request_plane config.request_plane = args.request_plane
config.enable_local_indexer = str(args.enable_local_indexer).lower() == "true" config.enable_local_indexer = str(args.enable_local_indexer).lower() == "true"
# Derive use_kv_events from publish_events_and_metrics
config.use_kv_events = config.publish_events_and_metrics
# Handle custom jinja template path expansion (environment variables and home directory) # Handle custom jinja template path expansion (environment variables and home directory)
if args.custom_jinja_template: if args.custom_jinja_template:
......
...@@ -84,6 +84,9 @@ class Config: ...@@ -84,6 +84,9 @@ class Config:
# Use vLLM's tokenizer for pre/post processing # Use vLLM's tokenizer for pre/post processing
use_vllm_tokenizer: bool = False use_vllm_tokenizer: bool = False
# Whether to enable NATS for KV events (derived from kv_events_config in overwrite_args)
use_kv_events: bool = False
def has_connector(self, connector_name: str) -> bool: def has_connector(self, connector_name: str) -> bool:
""" """
Check if a specific connector is enabled. Check if a specific connector is enabled.
...@@ -394,6 +397,7 @@ def parse_args() -> Config: ...@@ -394,6 +397,7 @@ def parse_args() -> Config:
config.request_plane = args.request_plane config.request_plane = args.request_plane
config.enable_local_indexer = args.enable_local_indexer config.enable_local_indexer = args.enable_local_indexer
config.use_vllm_tokenizer = args.use_vllm_tokenizer config.use_vllm_tokenizer = args.use_vllm_tokenizer
# use_kv_events is set later in overwrite_args() based on kv_events_config
# Validate custom Jinja template file exists if provided # Validate custom Jinja template file exists if provided
if config.custom_jinja_template is not None: if config.custom_jinja_template is not None:
...@@ -564,9 +568,13 @@ def overwrite_args(config): ...@@ -564,9 +568,13 @@ def overwrite_args(config):
if kv_transfer_config: if kv_transfer_config:
defaults["kv_transfer_config"] = kv_transfer_config defaults["kv_transfer_config"] = kv_transfer_config
defaults["kv_events_config"] = create_kv_events_config(config) kv_cfg = create_kv_events_config(config)
defaults["kv_events_config"] = kv_cfg
# Derive use_kv_events from whether kv_events_config is set AND enable_kv_cache_events is True
config.use_kv_events = kv_cfg is not None and kv_cfg.enable_kv_cache_events
logger.info( logger.info(
f"Using kv_events_config for publishing vLLM kv events over zmq: {defaults['kv_events_config']}" f"Using kv_events_config for publishing vLLM kv events over zmq: {kv_cfg} "
f"(use_kv_events={config.use_kv_events})"
) )
logger.debug("Setting Dynamo defaults for vLLM") logger.debug("Setting Dynamo defaults for vLLM")
......
...@@ -64,10 +64,13 @@ async def worker(): ...@@ -64,10 +64,13 @@ async def worker():
config = parse_args() config = parse_args()
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, config.store_kv, config.request_plane)
overwrite_args(config) overwrite_args(config)
# Enable NATS based on use_kv_events flag (derived from kv_events_config)
runtime = DistributedRuntime(
loop, config.store_kv, config.request_plane, config.use_kv_events
)
# Set up signal handler for graceful shutdown # Set up signal handler for graceful shutdown
def signal_handler(): def signal_handler():
asyncio.create_task(graceful_shutdown(runtime)) asyncio.create_task(graceful_shutdown(runtime))
...@@ -217,6 +220,13 @@ def setup_kv_event_publisher( ...@@ -217,6 +220,13 @@ def setup_kv_event_publisher(
if config.engine_args.kv_events_config is None: if config.engine_args.kv_events_config is None:
return None return None
# Check if kv_cache_events are explicitly disabled
if not config.engine_args.kv_events_config.enable_kv_cache_events:
logger.info(
"KV event publishing skipped: enable_kv_cache_events=False in kv_events_config"
)
return None
# Get data_parallel_size to create publishers for all dp_ranks # Get data_parallel_size to create publishers for all dp_ranks
data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1) data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
kv_publishers = [] kv_publishers = []
......
...@@ -45,11 +45,11 @@ In this section, we explain what happens under the hood when `DistributedRuntime ...@@ -45,11 +45,11 @@ In this section, we explain what happens under the hood when `DistributedRuntime
The hierarchy and naming in etcd and NATS may change over time, and this document might not reflect the latest changes. Regardless of such changes, the main concepts would remain the same. The hierarchy and naming in etcd and NATS may change over time, and this document might not reflect the latest changes. Regardless of such changes, the main concepts would remain the same.
``` ```
- `DistributedRuntime`: When a `DistributedRuntime` object is created, it establishes connections to the following two services: - `DistributedRuntime`: When a `DistributedRuntime` object is created, it establishes connections to the following services:
- etcd (dynamic mode only): for service discovery. In static mode, `DistributedRuntime` can operate without etcd. - etcd (dynamic mode only): for service discovery. In static mode, `DistributedRuntime` can operate without etcd.
- NATS (both static and dynamic mode): for messaging. - NATS (optional): for KV event messaging and router replica sync. NATS is enabled by default but can be disabled via the `enable_nats` parameter (e.g., using the `--no-kv-events` flag). When NATS is disabled, the system operates in approximate mode without KV event persistence. The legacy NATS-based request plane is also supported.
where etcd and NATS are two global services (there could be multiple etcd and NATS services for high availability). etcd and NATS are global services (there could be multiple instances for high availability).
For etcd, it also creates a primary lease and spins up a background task to keep the lease alive. All objects registered under this `DistributedRuntime` use this lease_id to maintain their life cycle. There is also a cancellation token that is tied to the primary lease. When the cancellation token is triggered or the background task fails, the primary lease is revoked or expires, and the kv pairs stored with this lease_id are removed. For etcd, it also creates a primary lease and spins up a background task to keep the lease alive. All objects registered under this `DistributedRuntime` use this lease_id to maintain their life cycle. There is also a cancellation token that is tied to the primary lease. When the cancellation token is triggered or the background task fails, the primary lease is revoked or expires, and the kv pairs stored with this lease_id are removed.
- `Namespace`: `Namespace`s are primarily a logical grouping mechanism and are not registered in etcd. Each one provides the root path for all components under that `Namespace`. - `Namespace`: `Namespace`s are primarily a logical grouping mechanism and are not registered in etcd. Each one provides the root path for all components under that `Namespace`.
......
...@@ -73,7 +73,7 @@ be operating within your distributed runtime. ...@@ -73,7 +73,7 @@ be operating within your distributed runtime.
The current examples use a hard-coded `namespace`. We will address the `namespace` collisions later. The current examples use a hard-coded `namespace`. We will address the `namespace` collisions later.
All examples require the `etcd` and `nats.io` pre-requisites to be running and available. Most examples require `etcd` for service discovery. `nats.io` is required for KV-aware routing with event tracking; for approximate mode (`--no-kv-events`), NATS is optional.
#### Rust `hello_world` #### Rust `hello_world`
......
...@@ -44,7 +44,7 @@ Dynamo has **two independent communication planes**: ...@@ -44,7 +44,7 @@ Dynamo has **two independent communication planes**:
- **Request plane** (**`DYN_REQUEST_PLANE`**): how **RPC requests** flow between components (frontend → router → worker), via `tcp`, `http`, or `nats`. - **Request plane** (**`DYN_REQUEST_PLANE`**): how **RPC requests** flow between components (frontend → router → worker), via `tcp`, `http`, or `nats`.
- **KV event plane** (currently only **NATS** is supported): how **KV cache events** (and optional router replica sync) are distributed/persisted for KV-aware routing. - **KV event plane** (currently only **NATS** is supported): how **KV cache events** (and optional router replica sync) are distributed/persisted for KV-aware routing.
**Note:** if you are using `tcp` or `http` request plane and choose to use NATS for KV events, you must still configure NATS server using `NATS_SERVER` environment variable, e.g. `NATS_SERVER=nats://nats-hostname:port`. **Note:** If you are using `tcp` or `http` request plane with KV events enabled (default), NATS is automatically initialized. You can optionally configure `NATS_SERVER` environment variable (e.g., `NATS_SERVER=nats://nats-hostname:port`) to specify a custom NATS server; otherwise, it defaults to `localhost:4222`. To completely disable NATS, use `--no-kv-events` on the frontend.
Because they are independent, you can mix them. Because they are independent, you can mix them.
...@@ -100,7 +100,7 @@ DYN_REQUEST_PLANE=tcp python -m dynamo.vllm --model Qwen/Qwen3-0.6B ...@@ -100,7 +100,7 @@ DYN_REQUEST_PLANE=tcp python -m dynamo.vllm --model Qwen/Qwen3-0.6B
**When to use TCP:** **When to use TCP:**
- Simple deployments with direct service-to-service communication (e.g. frontend to backend) - Simple deployments with direct service-to-service communication (e.g. frontend to backend)
- Minimal infrastructure requirements (**no NATS needed unless you enable KV-event-backed routing/replica sync**) - Minimal infrastructure requirements (NATS is initialized by default for KV events but can be disabled with `--no-kv-events`)
- Low-latency requirements - Low-latency requirements
**TCP Configuration Options:** **TCP Configuration Options:**
...@@ -172,7 +172,7 @@ DYN_REQUEST_PLANE=nats python -m dynamo.vllm --model Qwen/Qwen3-0.6B ...@@ -172,7 +172,7 @@ DYN_REQUEST_PLANE=nats python -m dynamo.vllm --model Qwen/Qwen3-0.6B
**When to use NATS:** **When to use NATS:**
- Production deployments with service discovery - Production deployments with service discovery
- Currently KV based routing require NATS. If you want to completely disable NATS, KV based routing won't be available - KV-aware routing with accurate cache state tracking (requires NATS for event transport). Note: approximate mode (`--no-kv-events`) provides KV routing without NATS but with reduced accuracy.
- Need for message replay and persistence features - Need for message replay and persistence features
Limitations: Limitations:
...@@ -301,6 +301,6 @@ curl http://localhost:8000/v1/chat/completions \ ...@@ -301,6 +301,6 @@ curl http://localhost:8000/v1/chat/completions \
### Resource Usage ### Resource Usage
- **TCP**: Minimal infrastructure (no additional services required) - **TCP**: Minimal infrastructure (NATS required only if using KV events, can disable with `--no-kv-events`)
- **HTTP**: Minimal infrastructure (no additional services required) - **HTTP**: Minimal infrastructure (NATS required only if using KV events, can disable with `--no-kv-events`)
- **NATS**: Requires running NATS server (additional memory/CPU) - **NATS**: Requires running NATS server (additional memory/CPU)
...@@ -213,6 +213,7 @@ The router uses KV events from workers by default to maintain an accurate global ...@@ -213,6 +213,7 @@ The router uses KV events from workers by default to maintain an accurate global
- Router predicts cache state based on routing decisions with TTL-based expiration and pruning - Router predicts cache state based on routing decisions with TTL-based expiration and pruning
- Tracks blocks from recent requests with configurable time-to-live - Tracks blocks from recent requests with configurable time-to-live
- Reduces overhead at the cost of routing accuracy - Reduces overhead at the cost of routing accuracy
- **NATS is not needed** - suitable for simpler deployments without NATS infrastructure
- Suitable for testing or when event processing becomes a bottleneck - Suitable for testing or when event processing becomes a bottleneck
## Tuning Guidelines ## Tuning Guidelines
......
...@@ -49,10 +49,15 @@ The main KV-aware routing arguments: ...@@ -49,10 +49,15 @@ The main KV-aware routing arguments:
> >
> **Request plane is independent of KV event transport.** > **Request plane is independent of KV event transport.**
> `DYN_REQUEST_PLANE` controls how **requests** are sent (TCP/HTTP/NATS), but KV-aware routing still uses **NATS** for KV events in both JetStream and NATS Core + Local Indexer modes. > `DYN_REQUEST_PLANE` controls how **requests** are sent (TCP/HTTP/NATS), but KV-aware routing still uses **NATS** for KV events in both JetStream and NATS Core + Local Indexer modes.
> If you run with `DYN_REQUEST_PLANE=tcp` (or `http`) and KV events enabled (default), you must also configure NATS, e.g. `NATS_SERVER=nats://...`. > When KV events are enabled (default), NATS is automatically initialized. You can optionally set `NATS_SERVER=nats://...` to specify a custom NATS server; otherwise, it defaults to `localhost:4222`.
> Only `--no-kv-events` removes the NATS requirement. > Use `--no-kv-events` to disable KV events and remove the NATS requirement entirely (with request plane being `tcp` or `http`).
> >
> When `--kv-overlap-score-weight` is set to 0, no KvIndexer is created and prefix matching is disabled (pure load balancing). When `--no-kv-events` is set, a KvIndexer is still created but no event subscriber is launched to consume KV events from workers. Instead, the router predicts cache state based on its own routing decisions with TTL-based expiration and pruning. In both cases, it's recommended to disable your backend workers from publishing events through `KvEventPublisher` to avoid event accumulation in JetStream. WIP to enable disabling publishing of KV events completely in these cases. > When `--kv-overlap-score-weight` is set to 0, no KvIndexer is created and prefix matching is disabled (pure load balancing). When `--no-kv-events` is set, a KvIndexer is still created but no event subscriber is launched to consume KV events from workers. Instead, the router predicts cache state based on its own routing decisions with TTL-based expiration and pruning.
>
> **Backend Configuration:** When using `--no-kv-events`, configure your backend workers to disable KV event publishing:
> - **vLLM**: Use `--kv-events-config '{"enable_kv_cache_events": false}'`
> - **SGLang**: Do not use `--kv-events-config`
> - **TRT-LLM**: Do not use `--publish-events-and-metrics`
> >
> The cli args `--router-ttl`, `--router-max-tree-size`, and `--router-prune-target-ratio` control local cache management when the router operates without receiving events from workers. When KV events are enabled (default), the router relies on worker-side eviction events and these parameters are ignored. > The cli args `--router-ttl`, `--router-max-tree-size`, and `--router-prune-target-ratio` control local cache management when the router operates without receiving events from workers. When KV events are enabled (default), the router relies on worker-side eviction events and these parameters are ignored.
......
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Launch a frontend with the KV router in approximate mode plus two vLLM
# workers that have KV cache event publishing disabled.

# Abort on the first failing command.
set -e
# On exit, signal every process in this script's process group so the
# backgrounded frontend and worker are stopped along with the script.
trap 'echo Cleaning up...; kill 0' EXIT

# Common configuration
MODEL="Qwen/Qwen3-0.6B"
BLOCK_SIZE=64

# Run the frontend with the KV router (--router-mode kv) in approximate mode
# (--no-kv-events); it is backgrounded so the workers can start below.
python -m dynamo.frontend \
    --router-mode kv \
    --no-kv-events &

# Run the workers.
# NOTE(review): the original comment here said "--enforce-eager is added for
# quick deployment" and should be removed for production use, but neither
# worker command below passes --enforce-eager. Confirm whether the flag or the
# comment is the intended one.
#
# If multiple workers are launched, they must not share the same system/metrics port.
# Use DYN_SYSTEM_PORT{1,2} so tests/launchers can provide a simple numbered port set.

# Worker 1 (GPU 0), backgrounded. KV cache events are disabled in its
# --kv-events-config.
DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT1:-8081} \
CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.vllm \
    --model "$MODEL" \
    --block-size "$BLOCK_SIZE" \
    --kv-events-config '{"enable_kv_cache_events": false}' &

# Worker 2 (GPU 1) runs in the foreground, so the script stays alive until this
# worker exits; the EXIT trap above then cleans up the background processes.
DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT2:-8082} \
VLLM_NIXL_SIDE_CHANNEL_PORT=20097 \
CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.vllm \
    --model "$MODEL" \
    --block-size "$BLOCK_SIZE" \
    --kv-events-config '{"enable_kv_cache_events": false}'
...@@ -534,7 +534,13 @@ enum ModelInput { ...@@ -534,7 +534,13 @@ enum ModelInput {
#[pymethods] #[pymethods]
impl DistributedRuntime { impl DistributedRuntime {
#[new] #[new]
fn new(event_loop: PyObject, store_kv: String, request_plane: String) -> PyResult<Self> { #[pyo3(signature = (event_loop, store_kv, request_plane, enable_nats=None))]
fn new(
event_loop: PyObject,
store_kv: String,
request_plane: String,
enable_nats: Option<bool>,
) -> PyResult<Self> {
let selected_kv_store: kv::Selector = store_kv.parse().map_err(to_pyerr)?; let selected_kv_store: kv::Selector = store_kv.parse().map_err(to_pyerr)?;
let request_plane: RequestPlaneMode = request_plane.parse().map_err(to_pyerr)?; let request_plane: RequestPlaneMode = request_plane.parse().map_err(to_pyerr)?;
...@@ -566,17 +572,19 @@ impl DistributedRuntime { ...@@ -566,17 +572,19 @@ impl DistributedRuntime {
}); });
} }
// NATS is used for more than just the NATS request-plane:
// - KV router events (JetStream or NATS core + local indexer)
// - inter-router replica sync (NATS core)
//
// NATS initialization logic:
// 1. If request_plane is NATS, always enable NATS
// 2. Otherwise, use enable_nats parameter (defaults to true for backward compat)
// Pass false to disable NATS (e.g., for approximate KV routing mode)
let enable_nats = enable_nats.unwrap_or(true); // Default to true
let runtime_config = DistributedConfig { let runtime_config = DistributedConfig {
store_backend: selected_kv_store, store_backend: selected_kv_store,
// NATS is used for more than just the NATS request-plane: nats_config: if request_plane.is_nats() || enable_nats {
// - KV router events (JetStream or NATS core + local indexer)
// - inter-router replica sync (NATS core)
//
// If a NATS server is configured via env, enable the client regardless of request plane.
nats_config: if request_plane.is_nats()
|| std::env::var(dynamo_runtime::config::environment_names::nats::NATS_SERVER)
.is_ok()
{
Some(dynamo_runtime::transports::nats::ClientOptions::default()) Some(dynamo_runtime::transports::nats::ClientOptions::default())
} else { } else {
None None
......
...@@ -38,7 +38,25 @@ class DistributedRuntime: ...@@ -38,7 +38,25 @@ class DistributedRuntime:
The runtime object for dynamo applications The runtime object for dynamo applications
""" """
... def __new__(
cls,
event_loop: Any,
store_kv: str,
request_plane: str,
enable_nats: Optional[bool] = None,
) -> "DistributedRuntime":
"""
Create a new DistributedRuntime.
Args:
event_loop: The asyncio event loop
store_kv: Key-value store backend ("etcd", "file", or "mem")
request_plane: Request plane transport ("tcp", "http", or "nats")
enable_nats: Whether to enable NATS for KV events. Defaults to True.
If request_plane is "nats", NATS is always enabled.
Pass False to disable NATS initialization (e.g., for approximate routing).
"""
...
def namespace(self, name: str) -> Namespace: def namespace(self, name: str) -> Namespace:
""" """
......
...@@ -21,13 +21,22 @@ from dynamo._core import Namespace as Namespace ...@@ -21,13 +21,22 @@ from dynamo._core import Namespace as Namespace
from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor
def dynamo_worker(): def dynamo_worker(enable_nats: bool = True):
"""
Decorator that creates a DistributedRuntime and passes it to the worker function.
Args:
enable_nats: Whether to enable NATS for KV events. Defaults to True.
If request_plane is "nats", NATS is always enabled.
Pass False (via --no-kv-events flag) to disable NATS initialization.
"""
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
request_plane = os.environ.get("DYN_REQUEST_PLANE", "tcp") request_plane = os.environ.get("DYN_REQUEST_PLANE", "tcp")
runtime = DistributedRuntime(loop, "etcd", request_plane) runtime = DistributedRuntime(loop, "etcd", request_plane, enable_nats)
await func(runtime, *args, **kwargs) await func(runtime, *args, **kwargs)
......
...@@ -397,13 +397,18 @@ impl DistributedRuntime { ...@@ -397,13 +397,18 @@ impl DistributedRuntime {
/// TODO: This is a temporary KV router measure for component/component.rs EventPublisher impl for /// TODO: This is a temporary KV router measure for component/component.rs EventPublisher impl for
/// Component, to allow it to publish to NATS. KV Router is the only user. /// Component, to allow it to publish to NATS. KV Router is the only user.
///
/// When NATS is not available (e.g., running in approximate mode with --no-kv-events),
/// this function returns Ok(()) silently since publishing is optional in that mode.
pub async fn kv_router_nats_publish( pub async fn kv_router_nats_publish(
&self, &self,
subject: String, subject: String,
payload: bytes::Bytes, payload: bytes::Bytes,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let Some(nats_client) = self.nats_client.as_ref() else { let Some(nats_client) = self.nats_client.as_ref() else {
anyhow::bail!("KV router's EventPublisher requires NATS"); // NATS not available - this is expected in approximate mode (--no-kv-events)
tracing::trace!("Skipping NATS publish (NATS not configured): {}", subject);
return Ok(());
}; };
Ok(nats_client.client().publish(subject, payload).await?) Ok(nats_client.client().publish(subject, payload).await?)
} }
......
...@@ -159,6 +159,10 @@ def run_serve_deployment( ...@@ -159,6 +159,10 @@ def run_serve_deployment(
) )
server_process.check_response(payload, response) server_process.check_response(payload, response)
# Call final_validation if the payload has one (e.g., CachedTokensChatPayload)
if hasattr(payload, "final_validation"):
payload.final_validation()
def params_with_model_mark(configs: Mapping[str, EngineConfig]): def params_with_model_mark(configs: Mapping[str, EngineConfig]):
"""Return pytest params for a config dict, adding a model marker per param. """Return pytest params for a config dict, adding a model marker per param.
......
...@@ -21,6 +21,7 @@ from tests.serve.lora_utils import MinioLoraConfig ...@@ -21,6 +21,7 @@ from tests.serve.lora_utils import MinioLoraConfig
from tests.utils.constants import DefaultPort from tests.utils.constants import DefaultPort
from tests.utils.engine_process import EngineConfig from tests.utils.engine_process import EngineConfig
from tests.utils.payload_builder import ( from tests.utils.payload_builder import (
cached_tokens_chat_payload,
chat_payload, chat_payload,
chat_payload_default, chat_payload_default,
chat_payload_with_logprobs, chat_payload_with_logprobs,
...@@ -190,6 +191,37 @@ vllm_configs = { ...@@ -190,6 +191,37 @@ vllm_configs = {
"DYN_LOG": "dynamo_llm::kv_router::publisher=trace,dynamo_llm::kv_router::scheduler=info", "DYN_LOG": "dynamo_llm::kv_router::publisher=trace,dynamo_llm::kv_router::scheduler=info",
}, },
), ),
"agg-router-approx": VLLMConfig(
name="agg-router-approx",
directory=vllm_dir,
script_name="agg_router_approx.sh",
marks=[pytest.mark.gpu_2, pytest.mark.post_merge],
model="Qwen/Qwen3-0.6B",
request_payloads=[
# Test approximate KV routing (--no-kv-events mode)
# Repeated requests should show cache-aware routing in logs
chat_payload_default(
repeat_count=3,
expected_log=[
# Verify scheduler is selecting workers with cache awareness
r"Selected worker: worker_id=\d+ dp_rank=.*?, logit: ",
# After first request, should see cached blocks being tracked
r"with \d+ cached blocks",
],
),
# Also test with cached tokens payload to verify usage field
cached_tokens_chat_payload(
repeat_count=3,
expected_log=[
# Verify routing decision shows cache hits
r"with \d+ cached blocks",
],
),
],
env={
"DYN_LOG": "dynamo_llm::kv_router::scheduler=info",
},
),
"disaggregated": VLLMConfig( "disaggregated": VLLMConfig(
name="disaggregated", name="disaggregated",
directory=vllm_dir, directory=vllm_dir,
......
...@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Union ...@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Union
from tests.utils.client import send_request from tests.utils.client import send_request
from tests.utils.constants import DefaultPort from tests.utils.constants import DefaultPort
from tests.utils.payloads import ( from tests.utils.payloads import (
CachedTokensChatPayload,
ChatPayload, ChatPayload,
ChatPayloadWithLogprobs, ChatPayloadWithLogprobs,
CompletionPayload, CompletionPayload,
...@@ -17,6 +18,17 @@ from tests.utils.payloads import ( ...@@ -17,6 +18,17 @@ from tests.utils.payloads import (
# Common default text prompt used across tests # Common default text prompt used across tests
TEXT_PROMPT = "Tell me a knock knock joke about AI." TEXT_PROMPT = "Tell me a knock knock joke about AI."
# Prompt for the prefix-caching tests. It must exceed 64 tokens (a typical block
# size) so that at least one complete block can be cached.
# The pieces below are adjacent string literals, so the value is a single line;
# the trailing space on each piece is the word separator at the join.
LONG_PROMPT_FOR_CACHING = (
    "In the heart of Eldoria, an ancient land of boundless magic and mysterious creatures, "
    "lies the long-forgotten city of Aeloria. Once a beacon of knowledge and power, Aeloria was buried beneath the "
    "shifting sands of time, lost to the world for centuries. You are an intrepid explorer, known for your unparalleled "
    "curiosity and courage, who has stumbled upon an ancient map hinting at the city's location. The map suggests that "
    "Aeloria holds a secret so profound that it has the potential to reshape the very fabric of reality. Your journey "
    "will take you through treacherous deserts, enchanted forests, and across perilous mountain ranges. "
    "Your Task: Character Background: Develop a detailed background for your character. Describe their motivations "
    "for seeking out Aeloria, their skills and weaknesses, and any personal connections to the ancient city or its legends."
)
def chat_payload_default( def chat_payload_default(
repeat_count: int = 3, repeat_count: int = 3,
...@@ -46,6 +58,54 @@ def chat_payload_default( ...@@ -46,6 +58,54 @@ def chat_payload_default(
) )
def cached_tokens_chat_payload(
    repeat_count: int = 3,
    expected_response: Optional[List[str]] = None,
    expected_log: Optional[List[str]] = None,
    max_tokens: int = 100,
    temperature: float = 0.0,
    min_cached_tokens: int = 64,
) -> CachedTokensChatPayload:
    """Build a non-streaming chat payload for checking cached-token reporting.

    The request sends ``LONG_PROMPT_FOR_CACHING`` as a single user message.
    It is meant for KV router cache-aware routing tests, where repeating an
    identical prompt is expected to produce cached tokens in the usage field.
    The long prompt is used because only complete blocks are cached, so a
    short prompt would not populate that field.

    Args:
        repeat_count: How many times the request is sent; more than one is
            needed to observe caching.
        expected_response: Strings expected in the response. A falsy value
            (``None`` or an empty list) selects a built-in keyword list
            taken from the prompt.
        expected_log: Log patterns expected to appear. A falsy value becomes
            an empty list.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        min_cached_tokens: Minimum cached tokens expected after the first
            request (default 64, i.e. one block); passed through to
            ``CachedTokensChatPayload``.

    Returns:
        A ``CachedTokensChatPayload`` configured for prefix-caching tests.
    """
    # Keywords from the prompt, used when the caller supplies no expectations.
    default_keywords = [
        "Aeloria",
        "Eldoria",
        "explorer",
        "ancient",
        "character",
        "background",
    ]
    user_message = {
        "role": "user",
        "content": LONG_PROMPT_FOR_CACHING,
    }
    request_body = {
        "messages": [user_message],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "stream": False,
    }
    return CachedTokensChatPayload(
        body=request_body,
        repeat_count=repeat_count,
        expected_log=expected_log if expected_log else [],
        expected_response=expected_response if expected_response else default_keywords,
        min_cached_tokens=min_cached_tokens,
    )
def completion_payload_default( def completion_payload_default(
repeat_count: int = 3, repeat_count: int = 3,
expected_response: Optional[List[str]] = None, expected_response: Optional[List[str]] = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment