Unverified Commit 14eceb43 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: rename KvPushRouter to KvRouter in python + more bindings removal (#6238)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 52271760
......@@ -26,7 +26,7 @@ from vllm.v1.engine.output_processor import OutputProcessor, OutputProcessorOutp
from dynamo.frontend.frontend_args import FrontendConfig
from dynamo.llm import (
KvPushRouter,
KvRouter,
ModelCardInstanceId,
ModelDeploymentCard,
PythonAsyncEngine,
......@@ -77,7 +77,7 @@ class VllmProcessor:
self,
tokenizer: TokenizerLike,
input_processor: InputProcessor,
router, # Client or KvPushRouter
router, # Client or KvRouter
output_processor: OutputProcessor,
tool_parser_class: type[ToolParser] | None,
reasoning_parser_class: type[ReasoningParser] | None,
......@@ -85,7 +85,7 @@ class VllmProcessor:
self.tokenizer = tokenizer
self.input_processor = input_processor
self.router = router
self.is_kv_router = isinstance(router, KvPushRouter)
self.is_kv_router = isinstance(router, KvRouter)
self.output_processor = output_processor
self.tool_parser_class = tool_parser_class
self.reasoning_parser_class = reasoning_parser_class
......@@ -445,7 +445,7 @@ class EngineFactory:
)
if self.router_config.router_mode == RouterMode.KV:
router = KvPushRouter(
router = KvRouter(
endpoint=generate_endpoint,
block_size=self.config.kv_cache_block_size or 16,
kv_router_config=self.router_config.kv_router_config,
......
......@@ -20,7 +20,7 @@ from typing import Optional
import uvloop
from dynamo.llm import KvPushRouter, KvRouterConfig
from dynamo.llm import KvRouter, KvRouterConfig
from dynamo.runtime import Client, DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
......@@ -42,7 +42,7 @@ class StandaloneRouterHandler:
self.worker_endpoint_path = worker_endpoint_path
self.block_size = block_size
self.kv_router_config = kv_router_config
self.kv_push_router: Optional[KvPushRouter] = None
self.kv_router: Optional[KvRouter] = None
self.worker_client: Optional[Client] = None
async def initialize(self):
......@@ -66,15 +66,14 @@ class StandaloneRouterHandler:
self.worker_client = await worker_endpoint.client()
# Create KvPushRouter with specified configuration
self.kv_push_router = KvPushRouter(
self.kv_router = KvRouter(
endpoint=worker_endpoint,
block_size=self.block_size,
kv_router_config=self.kv_router_config,
)
except Exception as e:
logger.error(f"Failed to initialize KvPushRouter: {e}")
logger.error(f"Failed to initialize KvRouter: {e}")
raise
async def generate(self, request):
......@@ -85,11 +84,11 @@ class StandaloneRouterHandler:
Wraps the request into PreprocessedRequest format and wraps worker responses
into LLMEngineOutput format.
"""
if self.kv_push_router is None:
logger.error("KvPushRouter not initialized - cannot process request")
if self.kv_router is None:
logger.error("KvRouter not initialized - cannot process request")
raise RuntimeError("Router not initialized")
# Wrap incoming request into PreprocessedRequest format for KvPushRouter
# Wrap incoming request into PreprocessedRequest format for KvRouter
# The request should already have most fields, but we ensure it has the structure
# Build routing hints from request (supports both nested routing object and legacy dp_rank)
routing = request.get("routing")
......@@ -112,8 +111,7 @@ class StandaloneRouterHandler:
"extra_args": request.get("extra_args"),
}
# Route and process through KvPushRouter
async for worker_output in await self.kv_push_router.generate_from_request(
async for worker_output in await self.kv_router.generate_from_request(
preprocessed_request
):
# Wrap worker output into LLMEngineOutput format
......@@ -142,11 +140,11 @@ class StandaloneRouterHandler:
overlap, but does NOT actually route the request or update router states.
It's useful for debugging, monitoring, or implementing custom routing logic.
"""
if self.kv_push_router is None:
logger.error("KvPushRouter not initialized - cannot get best worker")
if self.kv_router is None:
logger.error("KvRouter not initialized - cannot get best worker")
raise RuntimeError("Router not initialized")
(worker_id, _dp_rank, _overlap_blocks) = await self.kv_push_router.best_worker(
(worker_id, _dp_rank, _overlap_blocks) = await self.kv_router.best_worker(
token_ids, router_config_override
)
......@@ -220,6 +218,14 @@ def parse_args():
help="KV Router: Reset router state on startup, purging stream and object store. By default, states are persisted. WARNING: This can affect existing router replicas (default: False)",
)
parser.add_argument(
"--durable-kv-events",
action="store_true",
dest="durable_kv_events",
default=False,
help="KV Router: Enable durable KV events using NATS JetStream instead of NATS Core. By default, the router uses the generic event plane (NATS Core or ZMQ) with local_indexer mode. Use this flag when you need durability and multi-replica consistency. Requires NATS with JetStream enabled.",
)
parser.add_argument(
"--no-track-active-blocks",
action="store_false",
......@@ -228,6 +234,14 @@ def parse_args():
help="KV Router: Disable tracking of active blocks (blocks being used for ongoing generation). By default, active blocks are tracked for load balancing (default: True)",
)
parser.add_argument(
"--no-assume-kv-reuse",
action="store_false",
dest="router_assume_kv_reuse",
default=True,
help="KV Router: When tracking active blocks, do not assume KV cache reuse (generate random hashes instead of computing actual block hashes). Useful when KV cache reuse is not expected. By default, KV cache reuse is assumed.",
)
parser.add_argument(
"--track-output-blocks",
action="store_true",
......@@ -288,10 +302,12 @@ async def worker(runtime: DistributedRuntime):
f"overlap_score_weight={args.kv_overlap_score_weight}, "
f"router_temperature={args.router_temperature}, "
f"use_kv_events={args.use_kv_events}, "
f"durable_kv_events={args.durable_kv_events}, "
f"router_replica_sync={args.router_replica_sync}, "
f"router_reset_states={args.router_reset_states}, "
f"router_track_active_blocks={args.router_track_active_blocks}, "
f"router_track_output_blocks={args.router_track_output_blocks}, "
f"router_assume_kv_reuse={args.router_assume_kv_reuse}, "
f"router_ttl_secs={args.router_ttl_secs}, "
f"router_max_tree_size={args.router_max_tree_size}, "
f"router_prune_target_ratio={args.router_prune_target_ratio}"
......@@ -302,11 +318,13 @@ async def worker(runtime: DistributedRuntime):
overlap_score_weight=args.kv_overlap_score_weight,
router_temperature=args.router_temperature,
use_kv_events=args.use_kv_events,
durable_kv_events=args.durable_kv_events,
router_replica_sync=args.router_replica_sync,
router_snapshot_threshold=args.router_snapshot_threshold,
router_reset_states=args.router_reset_states,
router_track_active_blocks=args.router_track_active_blocks,
router_track_output_blocks=args.router_track_output_blocks,
router_assume_kv_reuse=args.router_assume_kv_reuse,
router_snapshot_threshold=args.router_snapshot_threshold,
router_reset_states=args.router_reset_states,
router_ttl_secs=args.router_ttl_secs,
router_max_tree_size=args.router_max_tree_size,
router_prune_target_ratio=args.router_prune_target_ratio,
......
......@@ -20,11 +20,7 @@ from dynamo.common.utils.prometheus import (
LLMBackendMetrics,
register_engine_metrics_callback,
)
from dynamo.llm import (
KvEventPublisher,
WorkerMetricsPublisher,
ZmqKvEventPublisherConfig,
)
from dynamo.llm import KvEventPublisher, WorkerMetricsPublisher
from dynamo.runtime import Component, Endpoint
from dynamo.sglang.args import Config
......@@ -256,19 +252,17 @@ class DynamoSglangPublisher:
zmq_ep = format_zmq_endpoint(zmq_ep, local_ip)
zmq_config = ZmqKvEventPublisherConfig(
worker_id=self.generate_endpoint.connection_id(),
kv_block_size=self.server_args.page_size,
zmq_endpoint=zmq_ep,
enable_local_indexer=self.dynamo_args.enable_local_indexer,
dp_rank=dp_rank,
)
logging.info(
f"Setting up ZMQ kv event subscriber for dp_rank={dp_rank} "
f"(connecting to {zmq_ep})"
)
publisher = KvEventPublisher(
component=self.component, zmq_config=zmq_config
component=self.component,
kv_block_size=self.server_args.page_size,
zmq_endpoint=zmq_ep,
zmq_topic="",
enable_local_indexer=self.dynamo_args.enable_local_indexer,
dp_rank=dp_rank,
)
self.kv_publishers.append(publisher)
......
......@@ -42,7 +42,6 @@ from dynamo.llm import (
ModelInput,
ModelRuntimeConfig,
ModelType,
ZmqKvEventPublisherConfig,
register_llm,
)
from dynamo.runtime import DistributedRuntime
......@@ -476,14 +475,11 @@ async def init_llm_worker(
consolidator_publisher = None
if consolidator_output_endpoint:
# Use the connect endpoint directly (already provided by get_consolidator_endpoints)
consolidator_config = ZmqKvEventPublisherConfig(
worker_id=int(endpoint.connection_id()),
consolidator_publisher = KvEventPublisher(
component,
kv_block_size=config.kv_block_size,
zmq_endpoint=consolidator_output_connect_endpoint,
zmq_topic="", # Empty topic = all topics
)
consolidator_publisher = KvEventPublisher(
component, zmq_config=consolidator_config
zmq_topic="",
)
logging.info(
f"Created worker-side publisher for consolidated events: "
......
......@@ -28,7 +28,6 @@ from dynamo.llm import (
ModelInput,
ModelRuntimeConfig,
ModelType,
ZmqKvEventPublisherConfig,
fetch_llm,
register_llm,
)
......@@ -341,14 +340,14 @@ def setup_kv_event_publisher(
f"KV event publisher for dp_rank={dp_rank} subscribing to vLLM at {zmq_endpoint}"
)
zmq_config = ZmqKvEventPublisherConfig(
worker_id=generate_endpoint.connection_id(),
kv_publisher = KvEventPublisher(
component=component,
kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint,
zmq_topic="",
enable_local_indexer=config.enable_local_indexer,
dp_rank=dp_rank,
)
kv_publisher = KvEventPublisher(component=component, zmq_config=zmq_config)
kv_publishers.append(kv_publisher)
logger.info(
......
......@@ -10,23 +10,23 @@ For quick start instructions, see the [Router README](README.md). This document
## Table of Contents
- [Using KvPushRouter Python API](#using-kvpushrouter-python-api)
- [Using KvRouter Python API](#using-kvrouter-python-api)
- [K8s Examples](#k8s-examples)
- [Routing Patterns](#routing-patterns)
- [Custom Routing Example: Minimizing TTFT](#custom-routing-example-minimizing-ttft)
- [KV Event Publishing for Custom Engines](#kv-event-publishing-for-custom-engines)
- [Global Router (Hierarchical Routing)](#global-router-hierarchical-routing)
## Using KvPushRouter Python API
## Using KvRouter Python API
Instead of launching the KV Router via command line, you can create a `KvPushRouter` object directly in Python. This allows per-request routing configuration overrides.
Instead of launching the KV Router via command line, you can create a `KvRouter` object directly in Python. This allows per-request routing configuration overrides.
>[!Warning]
> **Multiple Routers in Same Process**: If you need to run multiple `KvPushRouter` instances for fault tolerance or load distribution, you must launch them in **separate processes** (e.g., using `python -m dynamo.frontend` with different ports). Creating multiple `KvPushRouter` objects in the same Python process is not supported - they share the same cancellation token from the component's primary lease, so dropping one router will cancel all routers in that process. For in-process routing, use a single `KvPushRouter` instance.
> **Multiple Routers in Same Process**: If you need to run multiple `KvRouter` instances for fault tolerance or load distribution, you must launch them in **separate processes** (e.g., using `python -m dynamo.frontend` with different ports). Creating multiple `KvRouter` objects in the same Python process is not supported - they share the same cancellation token from the component's primary lease, so dropping one router will cancel all routers in that process. For in-process routing, use a single `KvRouter` instance.
### Methods
The `KvPushRouter` provides the following methods:
The `KvRouter` provides the following methods:
- **`generate(token_ids, model, ...)`**: Route and execute a request, returning an async stream of responses. Automatically handles worker selection, state tracking, and lifecycle management.
......@@ -53,7 +53,7 @@ python -m dynamo.vllm --model meta-llama/Llama-2-7b-hf
```python
import asyncio
from dynamollm import DistributedRuntime, KvPushRouter, KvRouterConfig
from dynamollm import DistributedRuntime, KvRouter, KvRouterConfig
async def main():
# Get runtime and create endpoint
......@@ -64,7 +64,7 @@ async def main():
# Create KV router
kv_router_config = KvRouterConfig()
router = KvPushRouter(
router = KvRouter(
endpoint=endpoint,
block_size=16,
kv_router_config=kv_router_config
......@@ -163,7 +163,7 @@ extraPodSpec:
## Routing Patterns
The `KvPushRouter` supports multiple usage patterns depending on your control requirements:
The `KvRouter` supports multiple usage patterns depending on your control requirements:
### 1. Automatic Routing (Recommended)
Call `generate()` directly and let the router handle everything:
......@@ -222,7 +222,7 @@ Here's an example of using `get_potential_loads()` to implement custom routing t
```python
import asyncio
from dynamo.llm import DistributedRuntime, KvPushRouter, KvRouterConfig
from dynamo.llm import DistributedRuntime, KvRouter, KvRouterConfig
async def minimize_ttft_routing():
# Setup router
......@@ -231,7 +231,7 @@ async def minimize_ttft_routing():
component = namespace.component("backend")
endpoint = component.endpoint("generate")
router = KvPushRouter(
router = KvRouter(
endpoint=endpoint,
block_size=16,
kv_router_config=KvRouterConfig()
......@@ -447,25 +447,19 @@ flowchart LR
#### Part 1: ZMQ Subscriber (Dynamo Bindings)
If your engine already publishes to ZMQ, use `ZmqKvEventPublisher` to subscribe and forward to NATS:
If your engine already publishes to ZMQ, use `KvEventPublisher` with `zmq_endpoint` to subscribe and forward to NATS:
```python
from dynamo.llm import ZmqKvEventPublisher, ZmqKvEventPublisherConfig
from dynamo.llm import KvEventPublisher
# Configure the ZMQ subscriber
config = ZmqKvEventPublisherConfig(
worker_id=endpoint.connection_id(),
# Create publisher - it automatically subscribes to ZMQ and forwards to NATS
kv_publisher = KvEventPublisher(
component=component,
kv_block_size=block_size,
zmq_endpoint="tcp://127.0.0.1:5557", # Where your engine publishes
zmq_topic="", # Subscribe to all topics
enable_local_indexer=False,
)
# Create publisher - it automatically subscribes to ZMQ and forwards to NATS
kv_publisher = ZmqKvEventPublisher(
component=component,
config=config,
)
```
#### Part 2: ZMQ Publisher (Pure Python)
......
......@@ -185,23 +185,17 @@ flowchart LR
### Part 1: ZMQ Subscriber (Dynamo Bindings)
If your engine already publishes to ZMQ, use `KvEventPublisher` with a `ZmqKvEventPublisherConfig` to subscribe and forward to NATS:
If your engine already publishes to ZMQ, use `KvEventPublisher` with `zmq_endpoint` (and optional `zmq_topic`) to subscribe and forward to NATS:
```python
from dynamo.llm import KvEventPublisher, ZmqKvEventPublisherConfig
# Configure the ZMQ subscriber
config = ZmqKvEventPublisherConfig(
worker_id=endpoint.connection_id(),
kv_block_size=block_size,
zmq_endpoint="tcp://127.0.0.1:5557", # Where your engine publishes
zmq_topic="", # Subscribe to all topics
)
from dynamo.llm import KvEventPublisher
# Create publisher - it automatically subscribes to ZMQ and forwards to NATS
kv_publisher = KvEventPublisher(
component=component,
zmq_config=config,
kv_block_size=block_size,
zmq_endpoint="tcp://127.0.0.1:5557", # Where your engine publishes
zmq_topic="", # Subscribe to all topics
)
```
......
......@@ -27,7 +27,7 @@ from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from transformers import AutoProcessor
from worker import TrtllmWorkers
from dynamo._core import compute_block_hash_for_seq_py
from dynamo._core import compute_block_hash_for_seq
logger = logging.getLogger(__name__)
......@@ -570,7 +570,7 @@ class ServiceAPI:
processed.image_offsets_list,
)
logger.debug(f"block_mm_infos: {block_mm_infos}")
local_hashes = compute_block_hash_for_seq_py(
local_hashes = compute_block_hash_for_seq(
processed.tokens, self.init_params.block_size, block_mm_infos
)
......
......@@ -19,7 +19,7 @@ from dataclasses import dataclass
import httpx
from dynamo.llm import compute_block_hash_for_seq_py
from dynamo.llm import compute_block_hash_for_seq
# Sample test images from COCO dataset
TEST_IMAGE_1 = "http://images.cocodataset.org/test2017/000000155781.jpg"
......@@ -340,7 +340,7 @@ class KvRouterTests:
def _test_mm_hash_computation(self):
"""
Test: Verify that compute_block_hash_for_seq_py produces different hashes
Test: Verify that compute_block_hash_for_seq produces different hashes
for same tokens with different mm_hash values.
"""
print("\n[MM-1] MM Hash Computation Test")
......@@ -351,15 +351,15 @@ class KvRouterTests:
block_size = 32
# Hash without MM info
hash_no_mm = compute_block_hash_for_seq_py(tokens, block_size)
hash_no_mm = compute_block_hash_for_seq(tokens, block_size)
# Hash with MM info (simulated mm_hash)
mm_info_1 = {"mm_objects": [{"mm_hash": 0xDEADBEEF, "offsets": [[0, 32]]}]}
hash_with_mm1 = compute_block_hash_for_seq_py(tokens, block_size, [mm_info_1])
hash_with_mm1 = compute_block_hash_for_seq(tokens, block_size, [mm_info_1])
# Hash with different MM info
mm_info_2 = {"mm_objects": [{"mm_hash": 0xCAFEBABE, "offsets": [[0, 32]]}]}
hash_with_mm2 = compute_block_hash_for_seq_py(tokens, block_size, [mm_info_2])
hash_with_mm2 = compute_block_hash_for_seq(tokens, block_size, [mm_info_2])
self.log(f"Hash without MM: {hash_no_mm}")
self.log(f"Hash with MM 1: {hash_with_mm1}")
......@@ -410,7 +410,7 @@ class KvRouterTests:
mm_info_a = {
"mm_objects": [{"mm_hash": 0x1111111111111111, "offsets": [[0, 64]]}]
}
hashes_a = compute_block_hash_for_seq_py(
hashes_a = compute_block_hash_for_seq(
tokens, block_size, [mm_info_a, mm_info_a]
)
......@@ -418,7 +418,7 @@ class KvRouterTests:
mm_info_b = {
"mm_objects": [{"mm_hash": 0x2222222222222222, "offsets": [[0, 64]]}]
}
hashes_b = compute_block_hash_for_seq_py(
hashes_b = compute_block_hash_for_seq(
tokens, block_size, [mm_info_b, mm_info_b]
)
......@@ -459,9 +459,9 @@ class KvRouterTests:
mm_info = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[0, 32]]}]}
# Compute hash multiple times
hash1 = compute_block_hash_for_seq_py(tokens, block_size, [mm_info])
hash2 = compute_block_hash_for_seq_py(tokens, block_size, [mm_info])
hash3 = compute_block_hash_for_seq_py(tokens, block_size, [mm_info])
hash1 = compute_block_hash_for_seq(tokens, block_size, [mm_info])
hash2 = compute_block_hash_for_seq(tokens, block_size, [mm_info])
hash3 = compute_block_hash_for_seq(tokens, block_size, [mm_info])
self.log(f"Hash 1: {hash1}")
self.log(f"Hash 2: {hash2}")
......@@ -499,19 +499,19 @@ class KvRouterTests:
# Image covers first block only
mm_info_first = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[0, 32]]}]}
hash_first = compute_block_hash_for_seq_py(
hash_first = compute_block_hash_for_seq(
tokens, block_size, [mm_info_first, None]
)
# Image covers second block only
mm_info_second = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[32, 64]]}]}
hash_second = compute_block_hash_for_seq_py(
hash_second = compute_block_hash_for_seq(
tokens, block_size, [None, mm_info_second]
)
# Image covers both blocks
mm_info_both = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[0, 64]]}]}
hash_both = compute_block_hash_for_seq_py(
hash_both = compute_block_hash_for_seq(
tokens, block_size, [mm_info_both, mm_info_both]
)
......@@ -555,12 +555,12 @@ class KvRouterTests:
# MM info only applies to middle block
mm_info = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[32, 64]]}]}
hashes_with_mm = compute_block_hash_for_seq_py(
hashes_with_mm = compute_block_hash_for_seq(
tokens, block_size, [None, mm_info, None]
)
# No MM info
hashes_without_mm = compute_block_hash_for_seq_py(tokens, block_size, None)
hashes_without_mm = compute_block_hash_for_seq(tokens, block_size, None)
self.log(f"Hashes with MM: {hashes_with_mm}")
self.log(f"Hashes without MM: {hashes_without_mm}")
......
......@@ -23,7 +23,7 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.engine.async_llm import AsyncLLM
import dynamo.nixl_connect as connect
from dynamo.llm import KvEventPublisher, ZmqKvEventPublisherConfig
from dynamo.llm import KvEventPublisher
from dynamo.runtime import Component, DistributedRuntime, Endpoint, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
......@@ -163,12 +163,11 @@ class VllmBaseWorker:
data_parallel_rank=self.engine_args.data_parallel_rank or 0,
).replace("*", "127.0.0.1")
zmq_config = ZmqKvEventPublisherConfig(
worker_id=endpoint.connection_id(),
self.kv_publisher = KvEventPublisher(
component=component,
kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint,
)
self.kv_publisher = KvEventPublisher(component=component, zmq_config=zmq_config)
logger.info(f"Reading Events from {zmq_endpoint}")
......
......@@ -170,15 +170,13 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::kv::KvEventPublisher>()?;
m.add_class::<llm::kv::RadixTree>()?;
m.add_class::<llm::kv::ZmqKvEventListener>()?;
m.add_class::<llm::kv::ZmqKvEventPublisherConfig>()?;
m.add_class::<llm::lora::LoRADownloader>()?;
m.add_class::<http::HttpService>()?;
m.add_class::<http::HttpAsyncEngine>()?;
m.add_class::<context::Context>()?;
m.add_class::<ModelType>()?;
m.add_class::<ModelInput>()?;
m.add_class::<llm::kv::KvPushRouter>()?;
m.add_class::<llm::kv::KvPushRouterStream>()?;
m.add_class::<llm::kv::KvRouter>()?;
m.add_class::<RouterMode>()?;
m.add_class::<kserve_grpc::KserveGrpcService>()?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
......@@ -987,10 +985,7 @@ impl Client {
_ => client.round_robin(request_ctx).await.map_err(to_pyerr)?,
};
tokio::spawn(process_stream(stream, tx));
Ok(AsyncResponseStream {
rx: Arc::new(Mutex::new(rx)),
annotated,
})
Ok(AsyncResponseStream::new(rx, annotated))
})
}
......@@ -1024,10 +1019,7 @@ impl Client {
_ => client.random(request_ctx).await.map_err(to_pyerr)?,
};
tokio::spawn(process_stream(stream, tx));
Ok(AsyncResponseStream {
rx: Arc::new(Mutex::new(rx)),
annotated,
})
Ok(AsyncResponseStream::new(rx, annotated))
})
}
......@@ -1068,10 +1060,7 @@ impl Client {
tokio::spawn(process_stream(stream, tx));
Ok(AsyncResponseStream {
rx: Arc::new(Mutex::new(rx)),
annotated,
})
Ok(AsyncResponseStream::new(rx, annotated))
})
}
}
......@@ -1106,11 +1095,23 @@ async fn process_stream(
}
#[pyclass]
struct AsyncResponseStream {
pub(crate) struct AsyncResponseStream {
rx: Arc<Mutex<tokio::sync::mpsc::Receiver<RsAnnotated<PyObject>>>>,
annotated: bool,
}
impl AsyncResponseStream {
pub(crate) fn new(
rx: tokio::sync::mpsc::Receiver<RsAnnotated<PyObject>>,
annotated: bool,
) -> Self {
Self {
rx: Arc::new(Mutex::new(rx)),
annotated,
}
}
}
#[pymethods]
impl AsyncResponseStream {
/// This method is required to implement the `AsyncIterator` protocol.
......
......@@ -12,8 +12,10 @@ use super::*;
use crate::Component;
use llm_rs::kv_router::protocols::compute_block_hash_for_seq;
use rs::pipeline::{AsyncEngine, SingleIn};
use rs::protocols::annotated::Annotated as RsAnnotated;
use tracing;
use llm_rs::kv_router::KvPushRouter as RsKvPushRouter;
use llm_rs::kv_router::protocols::*;
use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks, start_zmq_listener};
use llm_rs::protocols::common::timing::RequestTracker;
......@@ -21,7 +23,7 @@ use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
use serde_json::json;
#[pyfunction]
#[pyo3(signature = (tokens, kv_block_size, block_mm_infos=None))]
#[pyo3(name = "compute_block_hash_for_seq", signature = (tokens, kv_block_size, block_mm_infos=None))]
pub fn compute_block_hash_for_seq_py(
_py: Python,
tokens: Vec<u32>,
......@@ -99,54 +101,6 @@ impl WorkerMetricsPublisher {
}
}
#[pyclass]
#[derive(Clone)]
pub struct ZmqKvEventPublisherConfig {
#[pyo3(get, set)]
pub worker_id: WorkerId,
#[pyo3(get, set)]
pub kv_block_size: usize,
#[pyo3(get, set)]
pub zmq_endpoint: String,
#[pyo3(get, set)]
pub zmq_topic: String,
#[pyo3(get, set)]
pub enable_local_indexer: bool, // whether the underlying KvEventPublisher publishes to
// both global and worker-local KvIndexers
#[pyo3(get, set)]
pub dp_rank: DpRank, // data parallel rank for this publisher
}
#[pymethods]
impl ZmqKvEventPublisherConfig {
#[new]
#[pyo3(signature = (
worker_id,
kv_block_size,
zmq_endpoint = "tcp://127.0.0.1:5557".to_string(),
zmq_topic = "".to_string(),
enable_local_indexer = true,
dp_rank = 0
))]
pub fn new(
worker_id: WorkerId,
kv_block_size: usize,
zmq_endpoint: String,
zmq_topic: String,
enable_local_indexer: bool,
dp_rank: DpRank,
) -> Self {
Self {
worker_id,
kv_block_size,
zmq_endpoint,
zmq_topic,
enable_local_indexer,
dp_rank,
}
}
}
/// A ZMQ-based key-value cache event listener that operates independently
/// of the dynamo runtime or event plane infrastructure.
#[pyclass]
......@@ -231,33 +185,22 @@ pub(crate) struct KvEventPublisher {
#[pymethods]
impl KvEventPublisher {
#[new]
#[pyo3(signature = (component, worker_id=0, kv_block_size=0, dp_rank=0, enable_local_indexer=false, zmq_config=None))]
#[pyo3(signature = (component, worker_id=0, kv_block_size=0, dp_rank=0, enable_local_indexer=false, zmq_endpoint=None, zmq_topic=None))]
fn new(
component: Component,
worker_id: WorkerId,
kv_block_size: usize,
dp_rank: DpRank,
enable_local_indexer: bool,
zmq_config: Option<ZmqKvEventPublisherConfig>,
zmq_endpoint: Option<String>,
zmq_topic: Option<String>,
) -> PyResult<Self> {
// worker_id is not used; connection_id is inferred from the component.
let _ = worker_id;
// When zmq_config is provided, use its fields for kv_block_size/dp_rank/enable_local_indexer
let (kv_block_size, dp_rank, enable_local_indexer, source_config) =
if let Some(ref cfg) = zmq_config {
(
cfg.kv_block_size,
cfg.dp_rank,
cfg.enable_local_indexer,
Some(KvEventSourceConfig::Zmq {
endpoint: cfg.zmq_endpoint.clone(),
topic: cfg.zmq_topic.clone(),
}),
)
} else {
(kv_block_size, dp_rank, enable_local_indexer, None)
};
let source_config = zmq_endpoint.map(|endpoint| KvEventSourceConfig::Zmq {
endpoint,
topic: zmq_topic.unwrap_or_default(),
});
if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
......@@ -719,8 +662,8 @@ async fn create_kv_router_from_endpoint(
}
#[pyclass]
pub(crate) struct KvPushRouter {
inner: Arc<llm_rs::kv_router::KvPushRouter>,
pub(crate) struct KvRouter {
inner: Arc<RsKvPushRouter>,
}
/// Inject worker_id info from tracker into response's disaggregated_params.
......@@ -749,27 +692,25 @@ fn inject_worker_id_from_tracker(
}
// TODO: can this reuse the stream conversion method in Client bindings?
impl KvPushRouter {
impl KvRouter {
/// Helper method to process a request and create a Python async generator
fn process_request_to_stream<'p>(
py: Python<'p>,
inner: Arc<llm_rs::kv_router::KvPushRouter>,
inner: Arc<RsKvPushRouter>,
request: llm_rs::protocols::common::preprocessor::PreprocessedRequest,
tracker: Option<Arc<RequestTracker>>,
) -> PyResult<Bound<'p, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let single_in = SingleIn::new(request);
let stream = inner.generate(single_in).await.map_err(to_pyerr)?;
let (tx, rx) = tokio::sync::mpsc::channel(100);
let (tx, rx) = tokio::sync::mpsc::channel::<RsAnnotated<PyObject>>(100);
// Spawn a task to process the stream
tokio::spawn(async move {
let mut stream = stream;
let mut first_item = true;
let mut first_token_gauges_observed = false;
while let Some(mut response) = stream.next().await {
// Inject worker_id into first response if tracker is available
if first_item {
first_item = false;
if let (Some(tracker), Some(data)) = (&tracker, &mut response.data) {
......@@ -777,7 +718,6 @@ impl KvPushRouter {
}
}
// Observe per-worker TTFT/ISL gauges on first response with actual tokens
if !first_token_gauges_observed {
let has_tokens = response
.data
......@@ -792,7 +732,6 @@ impl KvPushRouter {
}
}
// Convert LLMEngineOutput to PyObject
let py_response = Python::with_gil(|py| {
pythonize(py, &response.data)
.map(|obj| obj.unbind())
......@@ -801,8 +740,8 @@ impl KvPushRouter {
match py_response {
Ok(obj) => {
if tx.send(obj).await.is_err() {
break; // Receiver dropped
if tx.send(RsAnnotated::from_data(obj)).await.is_err() {
break;
}
}
Err(e) => {
......@@ -812,23 +751,19 @@ impl KvPushRouter {
}
}
// Observe per-worker ITL gauge at stream end
if let Some(ref tracker) = tracker {
tracker.observe_finish_gauges();
}
});
// Return a Python async generator wrapper
Ok(KvPushRouterStream {
rx: Arc::new(tokio::sync::Mutex::new(rx)),
})
Ok(crate::AsyncResponseStream::new(rx, false))
})
}
}
#[pymethods]
impl KvPushRouter {
/// Create a new KvPushRouter for KV-aware routing to workers.
impl KvRouter {
/// Create a new KvRouter for KV-aware routing to workers.
///
/// # Arguments
/// * `endpoint` - The endpoint to route requests to
......@@ -869,8 +804,7 @@ impl KvPushRouter {
)
.await?;
// Create KvPushRouter (kv_router is already Arc<KvRouter>)
let kv_push_router = llm_rs::kv_router::KvPushRouter::new(push_router, kv_router);
let kv_push_router = RsKvPushRouter::new(push_router, kv_router);
Ok(Self {
inner: Arc::new(kv_push_router),
......@@ -1078,32 +1012,3 @@ impl KvPushRouter {
})
}
}
// Python async generator wrapper for the stream
#[pyclass]
pub(crate) struct KvPushRouterStream {
rx: Arc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<PyObject>>>,
}
#[pymethods]
impl KvPushRouterStream {
#[pyo3(name = "__aiter__")]
fn aiter(slf: Bound<'_, Self>) -> PyResult<Py<PyAny>> {
Ok(slf.clone().into_any().unbind())
}
#[pyo3(name = "__anext__")]
fn anext<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let rx = self.rx.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let mut rx = rx.lock().await;
match rx.recv().await {
Some(obj) => Ok(obj),
None => Err(pyo3::exceptions::PyStopAsyncIteration::new_err(
"Stream exhausted",
)),
}
})
}
}
......@@ -265,7 +265,7 @@ class ModelCardInstanceId:
...
def compute_block_hash_for_seq_py(
def compute_block_hash_for_seq(
tokens: List[int],
kv_block_size: int,
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None
......@@ -300,7 +300,7 @@ def compute_block_hash_for_seq_py(
... "mm_hash": 0xDEADBEEF,
... }]
... }
>>> hashes = compute_block_hash_for_seq_py(tokens, 32, [mm_info])
>>> hashes = compute_block_hash_for_seq(tokens, 32, [mm_info])
"""
...
......@@ -694,24 +694,25 @@ class KvEventPublisher:
kv_block_size: int = 0,
dp_rank: int = 0,
enable_local_indexer: bool = False,
zmq_config: Optional[ZmqKvEventPublisherConfig] = None,
zmq_endpoint: Optional[str] = None,
zmq_topic: Optional[str] = None,
) -> None:
"""
Create a `KvEventPublisher` object.
When zmq_config is provided, the publisher subscribes to a ZMQ socket for
When zmq_endpoint is provided, the publisher subscribes to a ZMQ socket for
incoming engine events (e.g. from SGLang/vLLM) and relays them to NATS.
The zmq_config fields override kv_block_size, dp_rank, and enable_local_indexer.
When zmq_config is None, events are pushed manually via publish_stored/publish_removed.
When zmq_endpoint is None, events are pushed manually via publish_stored/publish_removed.
Args:
component: The component to publish events for
worker_id: The worker ID (unused, inferred from component)
kv_block_size: The KV block size (must be > 0; ignored if zmq_config is set)
dp_rank: The data parallel rank (defaults to 0; ignored if zmq_config is set)
enable_local_indexer: Enable worker-local KV indexer (ignored if zmq_config is set)
zmq_config: Optional ZMQ configuration for relay mode
kv_block_size: The KV block size (must be > 0)
dp_rank: The data parallel rank (defaults to 0)
enable_local_indexer: Enable worker-local KV indexer
zmq_endpoint: Optional ZMQ endpoint for relay mode (e.g. "tcp://127.0.0.1:5557")
zmq_topic: ZMQ topic to subscribe to (defaults to "" when zmq_endpoint is set)
"""
def publish_stored(
......@@ -753,28 +754,6 @@ class KvEventPublisher:
"""
...
class ZmqKvEventPublisherConfig:
def __init__(
self,
worker_id: int,
kv_block_size: int,
zmq_endpoint: str = "tcp://127.0.0.1:5557",
zmq_topic: str = "",
enable_local_indexer: bool = True,
dp_rank: int = 0
) -> None:
"""
ZMQ configuration for KvEventPublisher relay mode.
:param worker_id: The worker ID.
:param kv_block_size: The block size for the key-value store.
:param zmq_endpoint: The ZeroMQ endpoint. Defaults to "tcp://127.0.0.1:5557".
:param zmq_topic: The ZeroMQ topic to subscribe to. Defaults to an empty string.
:param enable_local_indexer: Whether to enable the worker-local KV indexer. Defaults to True.
:param dp_rank: The data parallel rank for this publisher. Defaults to 0.
"""
...
class HttpService:
"""
A HTTP service for dynamo applications.
......@@ -1344,9 +1323,9 @@ class ZmqKvEventListener:
"""
...
class KvPushRouter:
class KvRouter:
"""
A KV-aware push router that performs intelligent routing based on KV cache overlap.
A KV-aware router that performs intelligent routing based on KV cache overlap.
"""
def __init__(
......@@ -1356,7 +1335,7 @@ class KvPushRouter:
kv_router_config: KvRouterConfig,
) -> None:
"""
Create a new KvPushRouter instance.
Create a new KvRouter instance.
Args:
endpoint: The endpoint to connect to for routing requests
......
......@@ -11,7 +11,7 @@ from dynamo._core import HttpAsyncEngine as HttpAsyncEngine
from dynamo._core import HttpService as HttpService
from dynamo._core import KserveGrpcService as KserveGrpcService
from dynamo._core import KvEventPublisher as KvEventPublisher
from dynamo._core import KvPushRouter as KvPushRouter
from dynamo._core import KvRouter as KvRouter
from dynamo._core import KvRouterConfig as KvRouterConfig
from dynamo._core import LoRADownloader as LoRADownloader
from dynamo._core import MediaDecoder as MediaDecoder
......@@ -28,8 +28,7 @@ from dynamo._core import RouterConfig as RouterConfig
from dynamo._core import RouterMode as RouterMode
from dynamo._core import WorkerMetricsPublisher as WorkerMetricsPublisher
from dynamo._core import ZmqKvEventListener as ZmqKvEventListener
from dynamo._core import ZmqKvEventPublisherConfig as ZmqKvEventPublisherConfig
from dynamo._core import compute_block_hash_for_seq_py as compute_block_hash_for_seq_py
from dynamo._core import compute_block_hash_for_seq as compute_block_hash_for_seq
from dynamo._core import fetch_llm as fetch_llm
from dynamo._core import lora_name_to_id as lora_name_to_id
from dynamo._core import make_engine
......
......@@ -23,7 +23,7 @@ from typing import Any
import pytest
from dynamo.llm import RadixTree, compute_block_hash_for_seq_py
from dynamo.llm import RadixTree, compute_block_hash_for_seq
pytestmark = pytest.mark.pre_merge
......@@ -215,17 +215,17 @@ def test_mm_block_hash_computation_basic():
tokens = [100] * DEFAULT_BLOCK_SIZE
# Without MM info
hashes_no_mm = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE)
hashes_no_mm = compute_block_hash_for_seq(tokens, DEFAULT_BLOCK_SIZE)
assert len(hashes_no_mm) == 1
# With MM info 1
hashes_mm1 = compute_block_hash_for_seq_py(
hashes_mm1 = compute_block_hash_for_seq(
tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_1)]
)
assert len(hashes_mm1) == 1
# With MM info 2
hashes_mm2 = compute_block_hash_for_seq_py(
hashes_mm2 = compute_block_hash_for_seq(
tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_2)]
)
assert len(hashes_mm2) == 1
......@@ -242,8 +242,8 @@ def test_mm_block_hash_determinism():
tokens = [100] * DEFAULT_BLOCK_SIZE
mm_info = [make_mm_info(MM_HASH_1)]
hash1 = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, mm_info)
hash2 = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, mm_info)
hash1 = compute_block_hash_for_seq(tokens, DEFAULT_BLOCK_SIZE, mm_info)
hash2 = compute_block_hash_for_seq(tokens, DEFAULT_BLOCK_SIZE, mm_info)
assert hash1 == hash2
......@@ -261,7 +261,7 @@ def test_mm_block_hash_multiple_blocks(block_size: int):
# One MM info per block
mm_infos = [make_mm_info(MM_HASH_1) for _ in range(num_blocks)]
hashes = compute_block_hash_for_seq_py(tokens, block_size, mm_infos)
hashes = compute_block_hash_for_seq(tokens, block_size, mm_infos)
assert len(hashes) == num_blocks
# Each block should have a unique hash (due to different tokens)
......@@ -277,7 +277,7 @@ def test_mm_block_hash_partial_block():
# MM info for each block
mm_infos = [make_mm_info(MM_HASH_1), make_mm_info(MM_HASH_2)]
hashes = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, mm_infos)
hashes = compute_block_hash_for_seq(tokens, DEFAULT_BLOCK_SIZE, mm_infos)
# Only complete blocks get hashes - partial blocks are not hashed
assert len(hashes) == 1
......@@ -291,10 +291,8 @@ def test_mm_block_hash_none_mm_info():
# Pass None for some blocks' MM info
mm_infos = [None]
hashes_with_none = compute_block_hash_for_seq_py(
tokens, DEFAULT_BLOCK_SIZE, mm_infos
)
hashes_without = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE)
hashes_with_none = compute_block_hash_for_seq(tokens, DEFAULT_BLOCK_SIZE, mm_infos)
hashes_without = compute_block_hash_for_seq(tokens, DEFAULT_BLOCK_SIZE)
# Both should produce the same result
assert hashes_with_none == hashes_without
......@@ -309,8 +307,8 @@ def test_mm_block_hash_different_offsets():
mm_info_1 = make_mm_info(MM_HASH_1, offsets=[[0, 10]])
mm_info_2 = make_mm_info(MM_HASH_1, offsets=[[5, 15]])
hash1 = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, [mm_info_1])
hash2 = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, [mm_info_2])
hash1 = compute_block_hash_for_seq(tokens, DEFAULT_BLOCK_SIZE, [mm_info_1])
hash2 = compute_block_hash_for_seq(tokens, DEFAULT_BLOCK_SIZE, [mm_info_2])
# Currently offsets are not included in hash computation - just mm_hash
# This behavior may change - update test if needed
......@@ -330,12 +328,12 @@ def test_mm_block_hash_multiple_mm_objects():
]
}
hashes = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, [mm_info])
hashes = compute_block_hash_for_seq(tokens, DEFAULT_BLOCK_SIZE, [mm_info])
assert len(hashes) == 1
# Compare with single MM object
single_mm_hashes = compute_block_hash_for_seq_py(
single_mm_hashes = compute_block_hash_for_seq(
tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_1)]
)
......@@ -349,7 +347,7 @@ def test_mm_block_hash_error_zero_block_size():
tokens = [100] * 32
with pytest.raises(ValueError, match="kv_block_size cannot be 0"):
compute_block_hash_for_seq_py(tokens, 0)
compute_block_hash_for_seq(tokens, 0)
# =============================================================================
......@@ -364,10 +362,10 @@ def test_integration_mm_hash_to_routing():
tokens = [100] * DEFAULT_BLOCK_SIZE
# Compute hashes for two different MM contents
hash_mm1 = compute_block_hash_for_seq_py(
hash_mm1 = compute_block_hash_for_seq(
tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_1)]
)[0]
hash_mm2 = compute_block_hash_for_seq_py(
hash_mm2 = compute_block_hash_for_seq(
tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_2)]
)[0]
......@@ -406,7 +404,7 @@ def test_integration_multiple_workers_same_tokens(num_workers: int):
# Store blocks for each worker
for worker_id, mm_hash in enumerate(mm_hashes):
block_hash = compute_block_hash_for_seq_py(
block_hash = compute_block_hash_for_seq(
tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(mm_hash)]
)[0]
......@@ -423,7 +421,7 @@ def test_integration_multiple_workers_same_tokens(num_workers: int):
# Query for each worker's block should match only that worker
for worker_id, mm_hash in enumerate(mm_hashes):
block_hash = compute_block_hash_for_seq_py(
block_hash = compute_block_hash_for_seq(
tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(mm_hash)]
)[0]
......
......@@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional
import aiohttp
import nats
from dynamo._core import DistributedRuntime, KvPushRouter, KvRouterConfig
from dynamo._core import DistributedRuntime, KvRouter, KvRouterConfig
from tests.utils.managed_process import ManagedProcess
if TYPE_CHECKING:
......@@ -188,7 +188,7 @@ async def wait_for_frontend_ready(
2. Sends a test POST to /v1/chat/completions to verify the request pipeline is functional
Use this when testing through the HTTP frontend server (dynamo.frontend).
For direct Python API testing with KvPushRouter, use wait_for_workers_ready() instead.
For direct Python API testing with KvRouter, use wait_for_workers_ready() instead.
Args:
frontend_url: Base URL of the frontend HTTP server (e.g., "http://localhost:8000")
......@@ -276,7 +276,7 @@ async def wait_for_frontend_ready(
async def wait_for_workers_ready(
endpoint,
router: KvPushRouter,
router: KvRouter,
expected_num_workers: int,
model_name: str,
) -> list[int]:
......@@ -289,7 +289,7 @@ async def wait_for_workers_ready(
Args:
endpoint: The endpoint object to get the client from
router: The KvPushRouter to use for sending warmup requests
router: The KvRouter to use for sending warmup requests
expected_num_workers: Number of workers to wait for
Returns:
......@@ -493,7 +493,7 @@ async def send_inflight_requests(urls: list, payload: dict, num_requests: int):
async def send_request_via_python_kv_router(
kv_python_router: KvPushRouter,
kv_python_router: KvRouter,
model_name: str,
token_ids: list,
initial_wait: float,
......@@ -609,7 +609,7 @@ async def send_request_via_python_kv_router(
)
logger.debug(
f"Successfully verified {max_tokens} tokens generated as expected via KvPushRouter with ignore_eos=True"
f"Successfully verified {max_tokens} tokens generated as expected via KvRouter with ignore_eos=True"
)
if return_worker_ids:
......@@ -883,9 +883,9 @@ def _test_python_router_bindings(
model_name: str,
num_workers: int,
):
"""Test KvPushRouter Python bindings with token streaming and config overrides.
"""Test KvRouter Python bindings with token streaming and config overrides.
Assumes engine_workers are already initialized. This test creates a KvPushRouter
Assumes engine_workers are already initialized. This test creates a KvRouter
Python object and sends three test requests to verify:
1. Token streaming with full router config overrides (overlap_score_weight, router_temperature)
2. Token streaming without any overrides (uses default config)
......@@ -906,19 +906,17 @@ def _test_python_router_bindings(
# Create KvRouterConfig with default settings
kv_router_config = KvRouterConfig()
# Create KvPushRouter Python object
kv_push_router = KvPushRouter(
# Create KvRouter Python object
kv_router = KvRouter(
endpoint=endpoint,
block_size=block_size,
kv_router_config=kv_router_config,
)
logger.info("Created KvPushRouter Python object")
logger.info("Created KvRouter Python object")
# Wait for workers to be ready
asyncio.run(
wait_for_workers_ready(endpoint, kv_push_router, num_workers, model_name)
)
asyncio.run(wait_for_workers_ready(endpoint, kv_router, num_workers, model_name))
# Generate random token IDs (100 to 200 tokens)
num_input_tokens = random.randint(100, 200)
......@@ -936,7 +934,7 @@ def _test_python_router_bindings(
logger.info(f"Testing with full router config overrides: {router_config_override}")
asyncio.run(
send_request_via_python_kv_router(
kv_python_router=kv_push_router,
kv_python_router=kv_router,
model_name=model_name,
token_ids=token_ids,
initial_wait=1.0,
......@@ -958,7 +956,7 @@ def _test_python_router_bindings(
logger.info("Testing without router config overrides")
asyncio.run(
send_request_via_python_kv_router(
kv_python_router=kv_push_router,
kv_python_router=kv_router,
model_name=model_name,
token_ids=token_ids[:50], # Use fewer tokens for second test,
initial_wait=1.0,
......@@ -981,7 +979,7 @@ def _test_python_router_bindings(
logger.info(f"Testing with partial router config overrides: {partial_override}")
asyncio.run(
send_request_via_python_kv_router(
kv_python_router=kv_push_router,
kv_python_router=kv_router,
model_name=model_name,
token_ids=token_ids[:30], # Use fewer tokens for third test,
initial_wait=1.0,
......@@ -999,7 +997,7 @@ def _test_python_router_bindings(
)
)
logger.info("KvPushRouter bindings test completed successfully")
logger.info("KvRouter bindings test completed successfully")
def _test_router_query_instance_id(
......@@ -1310,8 +1308,8 @@ def _test_router_indexers_sync(
"""Test that two KV routers have synchronized indexer states after processing requests.
Assumes engine_workers are already initialized. This test:
1. Creates first KvPushRouter (with its own runtime) and sends 25 requests (triggers snapshot at threshold=20)
2. Creates second KvPushRouter (with its own runtime, should sync from NATS snapshot)
1. Creates first KvRouter (with its own runtime) and sends 25 requests (triggers snapshot at threshold=20)
2. Creates second KvRouter (with its own runtime, should sync from NATS snapshot)
3. Sends 25 requests to second router
4. Verifies NATS object store contains the snapshot
5. Dumps states from both routers and compares them (should be identical)
......@@ -1395,23 +1393,21 @@ def _test_router_indexers_sync(
component1 = namespace1.component(engine_workers.component_name)
endpoint1 = component1.endpoint("generate")
kv_push_router1 = KvPushRouter(
kv_router1 = KvRouter(
endpoint=endpoint1,
block_size=block_size,
kv_router_config=kv_router_config,
)
# Wait for workers to be ready
await wait_for_workers_ready(
endpoint1, kv_push_router1, num_workers, model_name
)
await wait_for_workers_ready(endpoint1, kv_router1, num_workers, model_name)
# Send 25 requests to first router
logger.info("Sending 25 requests to first router")
# Send requests to first router
successful1 = await send_requests_to_router(
kv_push_router1, 25, "Router 1", endpoint1
kv_router1, 25, "Router 1", endpoint1
)
assert (
successful1 == 25
......@@ -1428,7 +1424,7 @@ def _test_router_indexers_sync(
logger.info("Sending 10 requests while NATS is down (via TCP)")
successful_offline1 = await send_requests_to_router(
kv_push_router1, 10, "Router 1 (NATS down)", endpoint1
kv_router1, 10, "Router 1 (NATS down)", endpoint1
)
assert (
successful_offline1 == 10
......@@ -1450,7 +1446,7 @@ def _test_router_indexers_sync(
component2 = namespace2.component(engine_workers.component_name)
endpoint2 = component2.endpoint("generate")
kv_push_router2 = KvPushRouter(
kv_router2 = KvRouter(
endpoint=endpoint2,
block_size=block_size,
kv_router_config=kv_router_config,
......@@ -1459,7 +1455,7 @@ def _test_router_indexers_sync(
# Send 25 requests to second router with initial retry loop
logger.info("Sending 25 requests to second router")
successful2 = await send_requests_to_router(
kv_push_router2, 25, "Router 2", endpoint2
kv_router2, 25, "Router 2", endpoint2
)
assert (
successful2 == 25
......@@ -1476,7 +1472,7 @@ def _test_router_indexers_sync(
logger.info("Sending 10 requests while NATS is down (via TCP)")
successful_offline2 = await send_requests_to_router(
kv_push_router2, 10, "Router 2 (NATS down)", endpoint2
kv_router2, 10, "Router 2 (NATS down)", endpoint2
)
assert (
successful_offline2 == 10
......@@ -1488,7 +1484,7 @@ def _test_router_indexers_sync(
logger.info("Sending 5 more requests after NATS recovery")
successful_recovery = await send_requests_to_router(
kv_push_router1, 5, "Router 1 (post-recovery)", endpoint1
kv_router1, 5, "Router 1 (post-recovery)", endpoint1
)
assert (
successful_recovery == 5
......@@ -1551,8 +1547,8 @@ def _test_router_indexers_sync(
# Dump states from both routers
logger.info("Dumping states from both routers")
state1_json = await kv_push_router1.dump_events()
state2_json = await kv_push_router2.dump_events()
state1_json = await kv_router1.dump_events()
state2_json = await kv_router2.dump_events()
# Parse JSON strings for comparison
state1 = json.loads(state1_json)
......@@ -1916,7 +1912,7 @@ def _test_router_decisions(
durable_kv_events=durable_kv_events,
router_event_threads=router_event_threads,
)
kv_push_router = KvPushRouter(
kv_router = KvRouter(
endpoint=endpoint,
block_size=block_size,
kv_router_config=kv_router_config,
......@@ -1930,7 +1926,7 @@ def _test_router_decisions(
# Wait for workers to be ready and get their instance IDs
worker_ids = await wait_for_workers_ready(
endpoint,
kv_push_router,
kv_router,
expected_num_workers=expected_num_instances,
model_name=model_name,
)
......@@ -1976,7 +1972,7 @@ def _test_router_decisions(
logger.info(log_msg)
result = await send_request_via_python_kv_router(
kv_python_router=kv_push_router,
kv_python_router=kv_router,
model_name=model_name,
token_ids=request,
initial_wait=1.0,
......@@ -2004,7 +2000,7 @@ def _test_router_decisions(
await asyncio.sleep(1)
# Dump events from the router
events_json = await kv_push_router.dump_events()
events_json = await kv_router.dump_events()
return events_json, forced_worker_id, forced_dp_rank, response_worker_ids
# Run the async test
......
......@@ -479,15 +479,15 @@ def test_mocker_kv_router_overload_503(
@pytest.mark.parametrize(
"durable_kv_events", [False], indirect=True
) # Use NATS Core (local indexer)
def test_kv_push_router_bindings(
def test_kv_router_bindings(
request,
runtime_services_dynamic_ports,
predownload_tokenizers,
request_plane,
durable_kv_events,
):
"""Test KvPushRouter Python bindings with mocker engines."""
logger.info("Starting KvPushRouter bindings test")
"""Test KvRouter Python bindings with mocker engines."""
logger.info("Starting KvRouter bindings test")
# Use local indexer (NATS Core mode)
mocker_args = {
"speedup_ratio": SPEEDUP_RATIO,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment