Unverified Commit 14eceb43 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: rename KvPushRouter to KvRouter in python + more bindings removal (#6238)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 52271760
...@@ -26,7 +26,7 @@ from vllm.v1.engine.output_processor import OutputProcessor, OutputProcessorOutp ...@@ -26,7 +26,7 @@ from vllm.v1.engine.output_processor import OutputProcessor, OutputProcessorOutp
from dynamo.frontend.frontend_args import FrontendConfig from dynamo.frontend.frontend_args import FrontendConfig
from dynamo.llm import ( from dynamo.llm import (
KvPushRouter, KvRouter,
ModelCardInstanceId, ModelCardInstanceId,
ModelDeploymentCard, ModelDeploymentCard,
PythonAsyncEngine, PythonAsyncEngine,
...@@ -77,7 +77,7 @@ class VllmProcessor: ...@@ -77,7 +77,7 @@ class VllmProcessor:
self, self,
tokenizer: TokenizerLike, tokenizer: TokenizerLike,
input_processor: InputProcessor, input_processor: InputProcessor,
router, # Client or KvPushRouter router, # Client or KvRouter
output_processor: OutputProcessor, output_processor: OutputProcessor,
tool_parser_class: type[ToolParser] | None, tool_parser_class: type[ToolParser] | None,
reasoning_parser_class: type[ReasoningParser] | None, reasoning_parser_class: type[ReasoningParser] | None,
...@@ -85,7 +85,7 @@ class VllmProcessor: ...@@ -85,7 +85,7 @@ class VllmProcessor:
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.input_processor = input_processor self.input_processor = input_processor
self.router = router self.router = router
self.is_kv_router = isinstance(router, KvPushRouter) self.is_kv_router = isinstance(router, KvRouter)
self.output_processor = output_processor self.output_processor = output_processor
self.tool_parser_class = tool_parser_class self.tool_parser_class = tool_parser_class
self.reasoning_parser_class = reasoning_parser_class self.reasoning_parser_class = reasoning_parser_class
...@@ -445,7 +445,7 @@ class EngineFactory: ...@@ -445,7 +445,7 @@ class EngineFactory:
) )
if self.router_config.router_mode == RouterMode.KV: if self.router_config.router_mode == RouterMode.KV:
router = KvPushRouter( router = KvRouter(
endpoint=generate_endpoint, endpoint=generate_endpoint,
block_size=self.config.kv_cache_block_size or 16, block_size=self.config.kv_cache_block_size or 16,
kv_router_config=self.router_config.kv_router_config, kv_router_config=self.router_config.kv_router_config,
......
...@@ -20,7 +20,7 @@ from typing import Optional ...@@ -20,7 +20,7 @@ from typing import Optional
import uvloop import uvloop
from dynamo.llm import KvPushRouter, KvRouterConfig from dynamo.llm import KvRouter, KvRouterConfig
from dynamo.runtime import Client, DistributedRuntime, dynamo_worker from dynamo.runtime import Client, DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -42,7 +42,7 @@ class StandaloneRouterHandler: ...@@ -42,7 +42,7 @@ class StandaloneRouterHandler:
self.worker_endpoint_path = worker_endpoint_path self.worker_endpoint_path = worker_endpoint_path
self.block_size = block_size self.block_size = block_size
self.kv_router_config = kv_router_config self.kv_router_config = kv_router_config
self.kv_push_router: Optional[KvPushRouter] = None self.kv_router: Optional[KvRouter] = None
self.worker_client: Optional[Client] = None self.worker_client: Optional[Client] = None
async def initialize(self): async def initialize(self):
...@@ -66,15 +66,14 @@ class StandaloneRouterHandler: ...@@ -66,15 +66,14 @@ class StandaloneRouterHandler:
self.worker_client = await worker_endpoint.client() self.worker_client = await worker_endpoint.client()
# Create KvPushRouter with specified configuration self.kv_router = KvRouter(
self.kv_push_router = KvPushRouter(
endpoint=worker_endpoint, endpoint=worker_endpoint,
block_size=self.block_size, block_size=self.block_size,
kv_router_config=self.kv_router_config, kv_router_config=self.kv_router_config,
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize KvPushRouter: {e}") logger.error(f"Failed to initialize KvRouter: {e}")
raise raise
async def generate(self, request): async def generate(self, request):
...@@ -85,11 +84,11 @@ class StandaloneRouterHandler: ...@@ -85,11 +84,11 @@ class StandaloneRouterHandler:
Wraps the request into PreprocessedRequest format and wraps worker responses Wraps the request into PreprocessedRequest format and wraps worker responses
into LLMEngineOutput format. into LLMEngineOutput format.
""" """
if self.kv_push_router is None: if self.kv_router is None:
logger.error("KvPushRouter not initialized - cannot process request") logger.error("KvRouter not initialized - cannot process request")
raise RuntimeError("Router not initialized") raise RuntimeError("Router not initialized")
# Wrap incoming request into PreprocessedRequest format for KvPushRouter # Wrap incoming request into PreprocessedRequest format for KvRouter
# The request should already have most fields, but we ensure it has the structure # The request should already have most fields, but we ensure it has the structure
# Build routing hints from request (supports both nested routing object and legacy dp_rank) # Build routing hints from request (supports both nested routing object and legacy dp_rank)
routing = request.get("routing") routing = request.get("routing")
...@@ -112,8 +111,7 @@ class StandaloneRouterHandler: ...@@ -112,8 +111,7 @@ class StandaloneRouterHandler:
"extra_args": request.get("extra_args"), "extra_args": request.get("extra_args"),
} }
# Route and process through KvPushRouter async for worker_output in await self.kv_router.generate_from_request(
async for worker_output in await self.kv_push_router.generate_from_request(
preprocessed_request preprocessed_request
): ):
# Wrap worker output into LLMEngineOutput format # Wrap worker output into LLMEngineOutput format
...@@ -142,11 +140,11 @@ class StandaloneRouterHandler: ...@@ -142,11 +140,11 @@ class StandaloneRouterHandler:
overlap, but does NOT actually route the request or update router states. overlap, but does NOT actually route the request or update router states.
It's useful for debugging, monitoring, or implementing custom routing logic. It's useful for debugging, monitoring, or implementing custom routing logic.
""" """
if self.kv_push_router is None: if self.kv_router is None:
logger.error("KvPushRouter not initialized - cannot get best worker") logger.error("KvRouter not initialized - cannot get best worker")
raise RuntimeError("Router not initialized") raise RuntimeError("Router not initialized")
(worker_id, _dp_rank, _overlap_blocks) = await self.kv_push_router.best_worker( (worker_id, _dp_rank, _overlap_blocks) = await self.kv_router.best_worker(
token_ids, router_config_override token_ids, router_config_override
) )
...@@ -220,6 +218,14 @@ def parse_args(): ...@@ -220,6 +218,14 @@ def parse_args():
help="KV Router: Reset router state on startup, purging stream and object store. By default, states are persisted. WARNING: This can affect existing router replicas (default: False)", help="KV Router: Reset router state on startup, purging stream and object store. By default, states are persisted. WARNING: This can affect existing router replicas (default: False)",
) )
parser.add_argument(
"--durable-kv-events",
action="store_true",
dest="durable_kv_events",
default=False,
help="KV Router: Enable durable KV events using NATS JetStream instead of NATS Core. By default, the router uses the generic event plane (NATS Core or ZMQ) with local_indexer mode. Use this flag when you need durability and multi-replica consistency. Requires NATS with JetStream enabled.",
)
parser.add_argument( parser.add_argument(
"--no-track-active-blocks", "--no-track-active-blocks",
action="store_false", action="store_false",
...@@ -228,6 +234,14 @@ def parse_args(): ...@@ -228,6 +234,14 @@ def parse_args():
help="KV Router: Disable tracking of active blocks (blocks being used for ongoing generation). By default, active blocks are tracked for load balancing (default: True)", help="KV Router: Disable tracking of active blocks (blocks being used for ongoing generation). By default, active blocks are tracked for load balancing (default: True)",
) )
parser.add_argument(
"--no-assume-kv-reuse",
action="store_false",
dest="router_assume_kv_reuse",
default=True,
help="KV Router: When tracking active blocks, do not assume KV cache reuse (generate random hashes instead of computing actual block hashes). Useful when KV cache reuse is not expected. By default, KV cache reuse is assumed.",
)
parser.add_argument( parser.add_argument(
"--track-output-blocks", "--track-output-blocks",
action="store_true", action="store_true",
...@@ -288,10 +302,12 @@ async def worker(runtime: DistributedRuntime): ...@@ -288,10 +302,12 @@ async def worker(runtime: DistributedRuntime):
f"overlap_score_weight={args.kv_overlap_score_weight}, " f"overlap_score_weight={args.kv_overlap_score_weight}, "
f"router_temperature={args.router_temperature}, " f"router_temperature={args.router_temperature}, "
f"use_kv_events={args.use_kv_events}, " f"use_kv_events={args.use_kv_events}, "
f"durable_kv_events={args.durable_kv_events}, "
f"router_replica_sync={args.router_replica_sync}, " f"router_replica_sync={args.router_replica_sync}, "
f"router_reset_states={args.router_reset_states}, " f"router_reset_states={args.router_reset_states}, "
f"router_track_active_blocks={args.router_track_active_blocks}, " f"router_track_active_blocks={args.router_track_active_blocks}, "
f"router_track_output_blocks={args.router_track_output_blocks}, " f"router_track_output_blocks={args.router_track_output_blocks}, "
f"router_assume_kv_reuse={args.router_assume_kv_reuse}, "
f"router_ttl_secs={args.router_ttl_secs}, " f"router_ttl_secs={args.router_ttl_secs}, "
f"router_max_tree_size={args.router_max_tree_size}, " f"router_max_tree_size={args.router_max_tree_size}, "
f"router_prune_target_ratio={args.router_prune_target_ratio}" f"router_prune_target_ratio={args.router_prune_target_ratio}"
...@@ -302,11 +318,13 @@ async def worker(runtime: DistributedRuntime): ...@@ -302,11 +318,13 @@ async def worker(runtime: DistributedRuntime):
overlap_score_weight=args.kv_overlap_score_weight, overlap_score_weight=args.kv_overlap_score_weight,
router_temperature=args.router_temperature, router_temperature=args.router_temperature,
use_kv_events=args.use_kv_events, use_kv_events=args.use_kv_events,
durable_kv_events=args.durable_kv_events,
router_replica_sync=args.router_replica_sync, router_replica_sync=args.router_replica_sync,
router_snapshot_threshold=args.router_snapshot_threshold,
router_reset_states=args.router_reset_states,
router_track_active_blocks=args.router_track_active_blocks, router_track_active_blocks=args.router_track_active_blocks,
router_track_output_blocks=args.router_track_output_blocks, router_track_output_blocks=args.router_track_output_blocks,
router_assume_kv_reuse=args.router_assume_kv_reuse,
router_snapshot_threshold=args.router_snapshot_threshold,
router_reset_states=args.router_reset_states,
router_ttl_secs=args.router_ttl_secs, router_ttl_secs=args.router_ttl_secs,
router_max_tree_size=args.router_max_tree_size, router_max_tree_size=args.router_max_tree_size,
router_prune_target_ratio=args.router_prune_target_ratio, router_prune_target_ratio=args.router_prune_target_ratio,
......
...@@ -20,11 +20,7 @@ from dynamo.common.utils.prometheus import ( ...@@ -20,11 +20,7 @@ from dynamo.common.utils.prometheus import (
LLMBackendMetrics, LLMBackendMetrics,
register_engine_metrics_callback, register_engine_metrics_callback,
) )
from dynamo.llm import ( from dynamo.llm import KvEventPublisher, WorkerMetricsPublisher
KvEventPublisher,
WorkerMetricsPublisher,
ZmqKvEventPublisherConfig,
)
from dynamo.runtime import Component, Endpoint from dynamo.runtime import Component, Endpoint
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
...@@ -256,19 +252,17 @@ class DynamoSglangPublisher: ...@@ -256,19 +252,17 @@ class DynamoSglangPublisher:
zmq_ep = format_zmq_endpoint(zmq_ep, local_ip) zmq_ep = format_zmq_endpoint(zmq_ep, local_ip)
zmq_config = ZmqKvEventPublisherConfig(
worker_id=self.generate_endpoint.connection_id(),
kv_block_size=self.server_args.page_size,
zmq_endpoint=zmq_ep,
enable_local_indexer=self.dynamo_args.enable_local_indexer,
dp_rank=dp_rank,
)
logging.info( logging.info(
f"Setting up ZMQ kv event subscriber for dp_rank={dp_rank} " f"Setting up ZMQ kv event subscriber for dp_rank={dp_rank} "
f"(connecting to {zmq_ep})" f"(connecting to {zmq_ep})"
) )
publisher = KvEventPublisher( publisher = KvEventPublisher(
component=self.component, zmq_config=zmq_config component=self.component,
kv_block_size=self.server_args.page_size,
zmq_endpoint=zmq_ep,
zmq_topic="",
enable_local_indexer=self.dynamo_args.enable_local_indexer,
dp_rank=dp_rank,
) )
self.kv_publishers.append(publisher) self.kv_publishers.append(publisher)
......
...@@ -42,7 +42,6 @@ from dynamo.llm import ( ...@@ -42,7 +42,6 @@ from dynamo.llm import (
ModelInput, ModelInput,
ModelRuntimeConfig, ModelRuntimeConfig,
ModelType, ModelType,
ZmqKvEventPublisherConfig,
register_llm, register_llm,
) )
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
...@@ -476,14 +475,11 @@ async def init_llm_worker( ...@@ -476,14 +475,11 @@ async def init_llm_worker(
consolidator_publisher = None consolidator_publisher = None
if consolidator_output_endpoint: if consolidator_output_endpoint:
# Use the connect endpoint directly (already provided by get_consolidator_endpoints) # Use the connect endpoint directly (already provided by get_consolidator_endpoints)
consolidator_config = ZmqKvEventPublisherConfig( consolidator_publisher = KvEventPublisher(
worker_id=int(endpoint.connection_id()), component,
kv_block_size=config.kv_block_size, kv_block_size=config.kv_block_size,
zmq_endpoint=consolidator_output_connect_endpoint, zmq_endpoint=consolidator_output_connect_endpoint,
zmq_topic="", # Empty topic = all topics zmq_topic="",
)
consolidator_publisher = KvEventPublisher(
component, zmq_config=consolidator_config
) )
logging.info( logging.info(
f"Created worker-side publisher for consolidated events: " f"Created worker-side publisher for consolidated events: "
......
...@@ -28,7 +28,6 @@ from dynamo.llm import ( ...@@ -28,7 +28,6 @@ from dynamo.llm import (
ModelInput, ModelInput,
ModelRuntimeConfig, ModelRuntimeConfig,
ModelType, ModelType,
ZmqKvEventPublisherConfig,
fetch_llm, fetch_llm,
register_llm, register_llm,
) )
...@@ -341,14 +340,14 @@ def setup_kv_event_publisher( ...@@ -341,14 +340,14 @@ def setup_kv_event_publisher(
f"KV event publisher for dp_rank={dp_rank} subscribing to vLLM at {zmq_endpoint}" f"KV event publisher for dp_rank={dp_rank} subscribing to vLLM at {zmq_endpoint}"
) )
zmq_config = ZmqKvEventPublisherConfig( kv_publisher = KvEventPublisher(
worker_id=generate_endpoint.connection_id(), component=component,
kv_block_size=vllm_config.cache_config.block_size, kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint, zmq_endpoint=zmq_endpoint,
zmq_topic="",
enable_local_indexer=config.enable_local_indexer, enable_local_indexer=config.enable_local_indexer,
dp_rank=dp_rank, dp_rank=dp_rank,
) )
kv_publisher = KvEventPublisher(component=component, zmq_config=zmq_config)
kv_publishers.append(kv_publisher) kv_publishers.append(kv_publisher)
logger.info( logger.info(
......
...@@ -10,23 +10,23 @@ For quick start instructions, see the [Router README](README.md). This document ...@@ -10,23 +10,23 @@ For quick start instructions, see the [Router README](README.md). This document
## Table of Contents ## Table of Contents
- [Using KvPushRouter Python API](#using-kvpushrouter-python-api) - [Using KvRouter Python API](#using-kvrouter-python-api)
- [K8s Examples](#k8s-examples) - [K8s Examples](#k8s-examples)
- [Routing Patterns](#routing-patterns) - [Routing Patterns](#routing-patterns)
- [Custom Routing Example: Minimizing TTFT](#custom-routing-example-minimizing-ttft) - [Custom Routing Example: Minimizing TTFT](#custom-routing-example-minimizing-ttft)
- [KV Event Publishing for Custom Engines](#kv-event-publishing-for-custom-engines) - [KV Event Publishing for Custom Engines](#kv-event-publishing-for-custom-engines)
- [Global Router (Hierarchical Routing)](#global-router-hierarchical-routing) - [Global Router (Hierarchical Routing)](#global-router-hierarchical-routing)
## Using KvPushRouter Python API ## Using KvRouter Python API
Instead of launching the KV Router via command line, you can create a `KvPushRouter` object directly in Python. This allows per-request routing configuration overrides. Instead of launching the KV Router via command line, you can create a `KvRouter` object directly in Python. This allows per-request routing configuration overrides.
>[!Warning] >[!Warning]
> **Multiple Routers in Same Process**: If you need to run multiple `KvPushRouter` instances for fault tolerance or load distribution, you must launch them in **separate processes** (e.g., using `python -m dynamo.frontend` with different ports). Creating multiple `KvPushRouter` objects in the same Python process is not supported - they share the same cancellation token from the component's primary lease, so dropping one router will cancel all routers in that process. For in-process routing, use a single `KvPushRouter` instance. > **Multiple Routers in Same Process**: If you need to run multiple `KvRouter` instances for fault tolerance or load distribution, you must launch them in **separate processes** (e.g., using `python -m dynamo.frontend` with different ports). Creating multiple `KvRouter` objects in the same Python process is not supported - they share the same cancellation token from the component's primary lease, so dropping one router will cancel all routers in that process. For in-process routing, use a single `KvRouter` instance.
### Methods ### Methods
The `KvPushRouter` provides the following methods: The `KvRouter` provides the following methods:
- **`generate(token_ids, model, ...)`**: Route and execute a request, returning an async stream of responses. Automatically handles worker selection, state tracking, and lifecycle management. - **`generate(token_ids, model, ...)`**: Route and execute a request, returning an async stream of responses. Automatically handles worker selection, state tracking, and lifecycle management.
...@@ -53,7 +53,7 @@ python -m dynamo.vllm --model meta-llama/Llama-2-7b-hf ...@@ -53,7 +53,7 @@ python -m dynamo.vllm --model meta-llama/Llama-2-7b-hf
```python ```python
import asyncio import asyncio
from dynamollm import DistributedRuntime, KvPushRouter, KvRouterConfig from dynamollm import DistributedRuntime, KvRouter, KvRouterConfig
async def main(): async def main():
# Get runtime and create endpoint # Get runtime and create endpoint
...@@ -64,7 +64,7 @@ async def main(): ...@@ -64,7 +64,7 @@ async def main():
# Create KV router # Create KV router
kv_router_config = KvRouterConfig() kv_router_config = KvRouterConfig()
router = KvPushRouter( router = KvRouter(
endpoint=endpoint, endpoint=endpoint,
block_size=16, block_size=16,
kv_router_config=kv_router_config kv_router_config=kv_router_config
...@@ -163,7 +163,7 @@ extraPodSpec: ...@@ -163,7 +163,7 @@ extraPodSpec:
## Routing Patterns ## Routing Patterns
The `KvPushRouter` supports multiple usage patterns depending on your control requirements: The `KvRouter` supports multiple usage patterns depending on your control requirements:
### 1. Automatic Routing (Recommended) ### 1. Automatic Routing (Recommended)
Call `generate()` directly and let the router handle everything: Call `generate()` directly and let the router handle everything:
...@@ -222,7 +222,7 @@ Here's an example of using `get_potential_loads()` to implement custom routing t ...@@ -222,7 +222,7 @@ Here's an example of using `get_potential_loads()` to implement custom routing t
```python ```python
import asyncio import asyncio
from dynamo.llm import DistributedRuntime, KvPushRouter, KvRouterConfig from dynamo.llm import DistributedRuntime, KvRouter, KvRouterConfig
async def minimize_ttft_routing(): async def minimize_ttft_routing():
# Setup router # Setup router
...@@ -231,7 +231,7 @@ async def minimize_ttft_routing(): ...@@ -231,7 +231,7 @@ async def minimize_ttft_routing():
component = namespace.component("backend") component = namespace.component("backend")
endpoint = component.endpoint("generate") endpoint = component.endpoint("generate")
router = KvPushRouter( router = KvRouter(
endpoint=endpoint, endpoint=endpoint,
block_size=16, block_size=16,
kv_router_config=KvRouterConfig() kv_router_config=KvRouterConfig()
...@@ -447,25 +447,19 @@ flowchart LR ...@@ -447,25 +447,19 @@ flowchart LR
#### Part 1: ZMQ Subscriber (Dynamo Bindings) #### Part 1: ZMQ Subscriber (Dynamo Bindings)
If your engine already publishes to ZMQ, use `ZmqKvEventPublisher` to subscribe and forward to NATS: If your engine already publishes to ZMQ, use `KvEventPublisher` with `zmq_endpoint` to subscribe and forward to NATS:
```python ```python
from dynamo.llm import ZmqKvEventPublisher, ZmqKvEventPublisherConfig from dynamo.llm import KvEventPublisher
# Configure the ZMQ subscriber # Create publisher - it automatically subscribes to ZMQ and forwards to NATS
config = ZmqKvEventPublisherConfig( kv_publisher = KvEventPublisher(
worker_id=endpoint.connection_id(), component=component,
kv_block_size=block_size, kv_block_size=block_size,
zmq_endpoint="tcp://127.0.0.1:5557", # Where your engine publishes zmq_endpoint="tcp://127.0.0.1:5557", # Where your engine publishes
zmq_topic="", # Subscribe to all topics zmq_topic="", # Subscribe to all topics
enable_local_indexer=False, enable_local_indexer=False,
) )
# Create publisher - it automatically subscribes to ZMQ and forwards to NATS
kv_publisher = ZmqKvEventPublisher(
component=component,
config=config,
)
``` ```
#### Part 2: ZMQ Publisher (Pure Python) #### Part 2: ZMQ Publisher (Pure Python)
......
...@@ -185,23 +185,17 @@ flowchart LR ...@@ -185,23 +185,17 @@ flowchart LR
### Part 1: ZMQ Subscriber (Dynamo Bindings) ### Part 1: ZMQ Subscriber (Dynamo Bindings)
If your engine already publishes to ZMQ, use `KvEventPublisher` with a `ZmqKvEventPublisherConfig` to subscribe and forward to NATS: If your engine already publishes to ZMQ, use `KvEventPublisher` with `zmq_endpoint` (and optional `zmq_topic`) to subscribe and forward to NATS:
```python ```python
from dynamo.llm import KvEventPublisher, ZmqKvEventPublisherConfig from dynamo.llm import KvEventPublisher
# Configure the ZMQ subscriber
config = ZmqKvEventPublisherConfig(
worker_id=endpoint.connection_id(),
kv_block_size=block_size,
zmq_endpoint="tcp://127.0.0.1:5557", # Where your engine publishes
zmq_topic="", # Subscribe to all topics
)
# Create publisher - it automatically subscribes to ZMQ and forwards to NATS # Create publisher - it automatically subscribes to ZMQ and forwards to NATS
kv_publisher = KvEventPublisher( kv_publisher = KvEventPublisher(
component=component, component=component,
zmq_config=config, kv_block_size=block_size,
zmq_endpoint="tcp://127.0.0.1:5557", # Where your engine publishes
zmq_topic="", # Subscribe to all topics
) )
``` ```
......
...@@ -27,7 +27,7 @@ from tensorrt_llm.llmapi.tokenizer import tokenizer_factory ...@@ -27,7 +27,7 @@ from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from transformers import AutoProcessor from transformers import AutoProcessor
from worker import TrtllmWorkers from worker import TrtllmWorkers
from dynamo._core import compute_block_hash_for_seq_py from dynamo._core import compute_block_hash_for_seq
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -570,7 +570,7 @@ class ServiceAPI: ...@@ -570,7 +570,7 @@ class ServiceAPI:
processed.image_offsets_list, processed.image_offsets_list,
) )
logger.debug(f"block_mm_infos: {block_mm_infos}") logger.debug(f"block_mm_infos: {block_mm_infos}")
local_hashes = compute_block_hash_for_seq_py( local_hashes = compute_block_hash_for_seq(
processed.tokens, self.init_params.block_size, block_mm_infos processed.tokens, self.init_params.block_size, block_mm_infos
) )
......
...@@ -19,7 +19,7 @@ from dataclasses import dataclass ...@@ -19,7 +19,7 @@ from dataclasses import dataclass
import httpx import httpx
from dynamo.llm import compute_block_hash_for_seq_py from dynamo.llm import compute_block_hash_for_seq
# Sample test images from COCO dataset # Sample test images from COCO dataset
TEST_IMAGE_1 = "http://images.cocodataset.org/test2017/000000155781.jpg" TEST_IMAGE_1 = "http://images.cocodataset.org/test2017/000000155781.jpg"
...@@ -340,7 +340,7 @@ class KvRouterTests: ...@@ -340,7 +340,7 @@ class KvRouterTests:
def _test_mm_hash_computation(self): def _test_mm_hash_computation(self):
""" """
Test: Verify that compute_block_hash_for_seq_py produces different hashes Test: Verify that compute_block_hash_for_seq produces different hashes
for same tokens with different mm_hash values. for same tokens with different mm_hash values.
""" """
print("\n[MM-1] MM Hash Computation Test") print("\n[MM-1] MM Hash Computation Test")
...@@ -351,15 +351,15 @@ class KvRouterTests: ...@@ -351,15 +351,15 @@ class KvRouterTests:
block_size = 32 block_size = 32
# Hash without MM info # Hash without MM info
hash_no_mm = compute_block_hash_for_seq_py(tokens, block_size) hash_no_mm = compute_block_hash_for_seq(tokens, block_size)
# Hash with MM info (simulated mm_hash) # Hash with MM info (simulated mm_hash)
mm_info_1 = {"mm_objects": [{"mm_hash": 0xDEADBEEF, "offsets": [[0, 32]]}]} mm_info_1 = {"mm_objects": [{"mm_hash": 0xDEADBEEF, "offsets": [[0, 32]]}]}
hash_with_mm1 = compute_block_hash_for_seq_py(tokens, block_size, [mm_info_1]) hash_with_mm1 = compute_block_hash_for_seq(tokens, block_size, [mm_info_1])
# Hash with different MM info # Hash with different MM info
mm_info_2 = {"mm_objects": [{"mm_hash": 0xCAFEBABE, "offsets": [[0, 32]]}]} mm_info_2 = {"mm_objects": [{"mm_hash": 0xCAFEBABE, "offsets": [[0, 32]]}]}
hash_with_mm2 = compute_block_hash_for_seq_py(tokens, block_size, [mm_info_2]) hash_with_mm2 = compute_block_hash_for_seq(tokens, block_size, [mm_info_2])
self.log(f"Hash without MM: {hash_no_mm}") self.log(f"Hash without MM: {hash_no_mm}")
self.log(f"Hash with MM 1: {hash_with_mm1}") self.log(f"Hash with MM 1: {hash_with_mm1}")
...@@ -410,7 +410,7 @@ class KvRouterTests: ...@@ -410,7 +410,7 @@ class KvRouterTests:
mm_info_a = { mm_info_a = {
"mm_objects": [{"mm_hash": 0x1111111111111111, "offsets": [[0, 64]]}] "mm_objects": [{"mm_hash": 0x1111111111111111, "offsets": [[0, 64]]}]
} }
hashes_a = compute_block_hash_for_seq_py( hashes_a = compute_block_hash_for_seq(
tokens, block_size, [mm_info_a, mm_info_a] tokens, block_size, [mm_info_a, mm_info_a]
) )
...@@ -418,7 +418,7 @@ class KvRouterTests: ...@@ -418,7 +418,7 @@ class KvRouterTests:
mm_info_b = { mm_info_b = {
"mm_objects": [{"mm_hash": 0x2222222222222222, "offsets": [[0, 64]]}] "mm_objects": [{"mm_hash": 0x2222222222222222, "offsets": [[0, 64]]}]
} }
hashes_b = compute_block_hash_for_seq_py( hashes_b = compute_block_hash_for_seq(
tokens, block_size, [mm_info_b, mm_info_b] tokens, block_size, [mm_info_b, mm_info_b]
) )
...@@ -459,9 +459,9 @@ class KvRouterTests: ...@@ -459,9 +459,9 @@ class KvRouterTests:
mm_info = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[0, 32]]}]} mm_info = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[0, 32]]}]}
# Compute hash multiple times # Compute hash multiple times
hash1 = compute_block_hash_for_seq_py(tokens, block_size, [mm_info]) hash1 = compute_block_hash_for_seq(tokens, block_size, [mm_info])
hash2 = compute_block_hash_for_seq_py(tokens, block_size, [mm_info]) hash2 = compute_block_hash_for_seq(tokens, block_size, [mm_info])
hash3 = compute_block_hash_for_seq_py(tokens, block_size, [mm_info]) hash3 = compute_block_hash_for_seq(tokens, block_size, [mm_info])
self.log(f"Hash 1: {hash1}") self.log(f"Hash 1: {hash1}")
self.log(f"Hash 2: {hash2}") self.log(f"Hash 2: {hash2}")
...@@ -499,19 +499,19 @@ class KvRouterTests: ...@@ -499,19 +499,19 @@ class KvRouterTests:
# Image covers first block only # Image covers first block only
mm_info_first = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[0, 32]]}]} mm_info_first = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[0, 32]]}]}
hash_first = compute_block_hash_for_seq_py( hash_first = compute_block_hash_for_seq(
tokens, block_size, [mm_info_first, None] tokens, block_size, [mm_info_first, None]
) )
# Image covers second block only # Image covers second block only
mm_info_second = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[32, 64]]}]} mm_info_second = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[32, 64]]}]}
hash_second = compute_block_hash_for_seq_py( hash_second = compute_block_hash_for_seq(
tokens, block_size, [None, mm_info_second] tokens, block_size, [None, mm_info_second]
) )
# Image covers both blocks # Image covers both blocks
mm_info_both = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[0, 64]]}]} mm_info_both = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[0, 64]]}]}
hash_both = compute_block_hash_for_seq_py( hash_both = compute_block_hash_for_seq(
tokens, block_size, [mm_info_both, mm_info_both] tokens, block_size, [mm_info_both, mm_info_both]
) )
...@@ -555,12 +555,12 @@ class KvRouterTests: ...@@ -555,12 +555,12 @@ class KvRouterTests:
# MM info only applies to middle block # MM info only applies to middle block
mm_info = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[32, 64]]}]} mm_info = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[32, 64]]}]}
hashes_with_mm = compute_block_hash_for_seq_py( hashes_with_mm = compute_block_hash_for_seq(
tokens, block_size, [None, mm_info, None] tokens, block_size, [None, mm_info, None]
) )
# No MM info # No MM info
hashes_without_mm = compute_block_hash_for_seq_py(tokens, block_size, None) hashes_without_mm = compute_block_hash_for_seq(tokens, block_size, None)
self.log(f"Hashes with MM: {hashes_with_mm}") self.log(f"Hashes with MM: {hashes_with_mm}")
self.log(f"Hashes without MM: {hashes_without_mm}") self.log(f"Hashes without MM: {hashes_without_mm}")
......
...@@ -23,7 +23,7 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser ...@@ -23,7 +23,7 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
import dynamo.nixl_connect as connect import dynamo.nixl_connect as connect
from dynamo.llm import KvEventPublisher, ZmqKvEventPublisherConfig from dynamo.llm import KvEventPublisher
from dynamo.runtime import Component, DistributedRuntime, Endpoint, dynamo_worker from dynamo.runtime import Component, DistributedRuntime, Endpoint, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -163,12 +163,11 @@ class VllmBaseWorker: ...@@ -163,12 +163,11 @@ class VllmBaseWorker:
data_parallel_rank=self.engine_args.data_parallel_rank or 0, data_parallel_rank=self.engine_args.data_parallel_rank or 0,
).replace("*", "127.0.0.1") ).replace("*", "127.0.0.1")
zmq_config = ZmqKvEventPublisherConfig( self.kv_publisher = KvEventPublisher(
worker_id=endpoint.connection_id(), component=component,
kv_block_size=vllm_config.cache_config.block_size, kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint, zmq_endpoint=zmq_endpoint,
) )
self.kv_publisher = KvEventPublisher(component=component, zmq_config=zmq_config)
logger.info(f"Reading Events from {zmq_endpoint}") logger.info(f"Reading Events from {zmq_endpoint}")
......
...@@ -170,15 +170,13 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -170,15 +170,13 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::kv::KvEventPublisher>()?; m.add_class::<llm::kv::KvEventPublisher>()?;
m.add_class::<llm::kv::RadixTree>()?; m.add_class::<llm::kv::RadixTree>()?;
m.add_class::<llm::kv::ZmqKvEventListener>()?; m.add_class::<llm::kv::ZmqKvEventListener>()?;
m.add_class::<llm::kv::ZmqKvEventPublisherConfig>()?;
m.add_class::<llm::lora::LoRADownloader>()?; m.add_class::<llm::lora::LoRADownloader>()?;
m.add_class::<http::HttpService>()?; m.add_class::<http::HttpService>()?;
m.add_class::<http::HttpAsyncEngine>()?; m.add_class::<http::HttpAsyncEngine>()?;
m.add_class::<context::Context>()?; m.add_class::<context::Context>()?;
m.add_class::<ModelType>()?; m.add_class::<ModelType>()?;
m.add_class::<ModelInput>()?; m.add_class::<ModelInput>()?;
m.add_class::<llm::kv::KvPushRouter>()?; m.add_class::<llm::kv::KvRouter>()?;
m.add_class::<llm::kv::KvPushRouterStream>()?;
m.add_class::<RouterMode>()?; m.add_class::<RouterMode>()?;
m.add_class::<kserve_grpc::KserveGrpcService>()?; m.add_class::<kserve_grpc::KserveGrpcService>()?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?; m.add("__version__", env!("CARGO_PKG_VERSION"))?;
...@@ -987,10 +985,7 @@ impl Client { ...@@ -987,10 +985,7 @@ impl Client {
_ => client.round_robin(request_ctx).await.map_err(to_pyerr)?, _ => client.round_robin(request_ctx).await.map_err(to_pyerr)?,
}; };
tokio::spawn(process_stream(stream, tx)); tokio::spawn(process_stream(stream, tx));
Ok(AsyncResponseStream { Ok(AsyncResponseStream::new(rx, annotated))
rx: Arc::new(Mutex::new(rx)),
annotated,
})
}) })
} }
...@@ -1024,10 +1019,7 @@ impl Client { ...@@ -1024,10 +1019,7 @@ impl Client {
_ => client.random(request_ctx).await.map_err(to_pyerr)?, _ => client.random(request_ctx).await.map_err(to_pyerr)?,
}; };
tokio::spawn(process_stream(stream, tx)); tokio::spawn(process_stream(stream, tx));
Ok(AsyncResponseStream { Ok(AsyncResponseStream::new(rx, annotated))
rx: Arc::new(Mutex::new(rx)),
annotated,
})
}) })
} }
...@@ -1068,10 +1060,7 @@ impl Client { ...@@ -1068,10 +1060,7 @@ impl Client {
tokio::spawn(process_stream(stream, tx)); tokio::spawn(process_stream(stream, tx));
Ok(AsyncResponseStream { Ok(AsyncResponseStream::new(rx, annotated))
rx: Arc::new(Mutex::new(rx)),
annotated,
})
}) })
} }
} }
...@@ -1106,11 +1095,23 @@ async fn process_stream( ...@@ -1106,11 +1095,23 @@ async fn process_stream(
} }
#[pyclass] #[pyclass]
struct AsyncResponseStream { pub(crate) struct AsyncResponseStream {
rx: Arc<Mutex<tokio::sync::mpsc::Receiver<RsAnnotated<PyObject>>>>, rx: Arc<Mutex<tokio::sync::mpsc::Receiver<RsAnnotated<PyObject>>>>,
annotated: bool, annotated: bool,
} }
impl AsyncResponseStream {
pub(crate) fn new(
rx: tokio::sync::mpsc::Receiver<RsAnnotated<PyObject>>,
annotated: bool,
) -> Self {
Self {
rx: Arc::new(Mutex::new(rx)),
annotated,
}
}
}
#[pymethods] #[pymethods]
impl AsyncResponseStream { impl AsyncResponseStream {
/// This method is required to implement the `AsyncIterator` protocol. /// This method is required to implement the `AsyncIterator` protocol.
......
...@@ -12,8 +12,10 @@ use super::*; ...@@ -12,8 +12,10 @@ use super::*;
use crate::Component; use crate::Component;
use llm_rs::kv_router::protocols::compute_block_hash_for_seq; use llm_rs::kv_router::protocols::compute_block_hash_for_seq;
use rs::pipeline::{AsyncEngine, SingleIn}; use rs::pipeline::{AsyncEngine, SingleIn};
use rs::protocols::annotated::Annotated as RsAnnotated;
use tracing; use tracing;
use llm_rs::kv_router::KvPushRouter as RsKvPushRouter;
use llm_rs::kv_router::protocols::*; use llm_rs::kv_router::protocols::*;
use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks, start_zmq_listener}; use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks, start_zmq_listener};
use llm_rs::protocols::common::timing::RequestTracker; use llm_rs::protocols::common::timing::RequestTracker;
...@@ -21,7 +23,7 @@ use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions}; ...@@ -21,7 +23,7 @@ use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
use serde_json::json; use serde_json::json;
#[pyfunction] #[pyfunction]
#[pyo3(signature = (tokens, kv_block_size, block_mm_infos=None))] #[pyo3(name = "compute_block_hash_for_seq", signature = (tokens, kv_block_size, block_mm_infos=None))]
pub fn compute_block_hash_for_seq_py( pub fn compute_block_hash_for_seq_py(
_py: Python, _py: Python,
tokens: Vec<u32>, tokens: Vec<u32>,
...@@ -99,54 +101,6 @@ impl WorkerMetricsPublisher { ...@@ -99,54 +101,6 @@ impl WorkerMetricsPublisher {
} }
} }
#[pyclass]
#[derive(Clone)]
pub struct ZmqKvEventPublisherConfig {
#[pyo3(get, set)]
pub worker_id: WorkerId,
#[pyo3(get, set)]
pub kv_block_size: usize,
#[pyo3(get, set)]
pub zmq_endpoint: String,
#[pyo3(get, set)]
pub zmq_topic: String,
#[pyo3(get, set)]
pub enable_local_indexer: bool, // whether the underlying KvEventPublisher publishes to
// both global and worker-local KvIndexers
#[pyo3(get, set)]
pub dp_rank: DpRank, // data parallel rank for this publisher
}
#[pymethods]
impl ZmqKvEventPublisherConfig {
#[new]
#[pyo3(signature = (
worker_id,
kv_block_size,
zmq_endpoint = "tcp://127.0.0.1:5557".to_string(),
zmq_topic = "".to_string(),
enable_local_indexer = true,
dp_rank = 0
))]
pub fn new(
worker_id: WorkerId,
kv_block_size: usize,
zmq_endpoint: String,
zmq_topic: String,
enable_local_indexer: bool,
dp_rank: DpRank,
) -> Self {
Self {
worker_id,
kv_block_size,
zmq_endpoint,
zmq_topic,
enable_local_indexer,
dp_rank,
}
}
}
/// A ZMQ-based key-value cache event listener that operates independently /// A ZMQ-based key-value cache event listener that operates independently
/// of the dynamo runtime or event plane infrastructure. /// of the dynamo runtime or event plane infrastructure.
#[pyclass] #[pyclass]
...@@ -231,33 +185,22 @@ pub(crate) struct KvEventPublisher { ...@@ -231,33 +185,22 @@ pub(crate) struct KvEventPublisher {
#[pymethods] #[pymethods]
impl KvEventPublisher { impl KvEventPublisher {
#[new] #[new]
#[pyo3(signature = (component, worker_id=0, kv_block_size=0, dp_rank=0, enable_local_indexer=false, zmq_config=None))] #[pyo3(signature = (component, worker_id=0, kv_block_size=0, dp_rank=0, enable_local_indexer=false, zmq_endpoint=None, zmq_topic=None))]
fn new( fn new(
component: Component, component: Component,
worker_id: WorkerId, worker_id: WorkerId,
kv_block_size: usize, kv_block_size: usize,
dp_rank: DpRank, dp_rank: DpRank,
enable_local_indexer: bool, enable_local_indexer: bool,
zmq_config: Option<ZmqKvEventPublisherConfig>, zmq_endpoint: Option<String>,
zmq_topic: Option<String>,
) -> PyResult<Self> { ) -> PyResult<Self> {
// worker_id is not used; connection_id is inferred from the component.
let _ = worker_id; let _ = worker_id;
// When zmq_config is provided, use its fields for kv_block_size/dp_rank/enable_local_indexer let source_config = zmq_endpoint.map(|endpoint| KvEventSourceConfig::Zmq {
let (kv_block_size, dp_rank, enable_local_indexer, source_config) = endpoint,
if let Some(ref cfg) = zmq_config { topic: zmq_topic.unwrap_or_default(),
( });
cfg.kv_block_size,
cfg.dp_rank,
cfg.enable_local_indexer,
Some(KvEventSourceConfig::Zmq {
endpoint: cfg.zmq_endpoint.clone(),
topic: cfg.zmq_topic.clone(),
}),
)
} else {
(kv_block_size, dp_rank, enable_local_indexer, None)
};
if kv_block_size == 0 { if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0"))); return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
...@@ -719,8 +662,8 @@ async fn create_kv_router_from_endpoint( ...@@ -719,8 +662,8 @@ async fn create_kv_router_from_endpoint(
} }
#[pyclass] #[pyclass]
pub(crate) struct KvPushRouter { pub(crate) struct KvRouter {
inner: Arc<llm_rs::kv_router::KvPushRouter>, inner: Arc<RsKvPushRouter>,
} }
/// Inject worker_id info from tracker into response's disaggregated_params. /// Inject worker_id info from tracker into response's disaggregated_params.
...@@ -749,27 +692,25 @@ fn inject_worker_id_from_tracker( ...@@ -749,27 +692,25 @@ fn inject_worker_id_from_tracker(
} }
// TODO: can this reuse the stream conversion method in Client bindings? // TODO: can this reuse the stream conversion method in Client bindings?
impl KvPushRouter { impl KvRouter {
/// Helper method to process a request and create a Python async generator /// Helper method to process a request and create a Python async generator
fn process_request_to_stream<'p>( fn process_request_to_stream<'p>(
py: Python<'p>, py: Python<'p>,
inner: Arc<llm_rs::kv_router::KvPushRouter>, inner: Arc<RsKvPushRouter>,
request: llm_rs::protocols::common::preprocessor::PreprocessedRequest, request: llm_rs::protocols::common::preprocessor::PreprocessedRequest,
tracker: Option<Arc<RequestTracker>>, tracker: Option<Arc<RequestTracker>>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let single_in = SingleIn::new(request); let single_in = SingleIn::new(request);
let stream = inner.generate(single_in).await.map_err(to_pyerr)?; let stream = inner.generate(single_in).await.map_err(to_pyerr)?;
let (tx, rx) = tokio::sync::mpsc::channel(100); let (tx, rx) = tokio::sync::mpsc::channel::<RsAnnotated<PyObject>>(100);
// Spawn a task to process the stream
tokio::spawn(async move { tokio::spawn(async move {
let mut stream = stream; let mut stream = stream;
let mut first_item = true; let mut first_item = true;
let mut first_token_gauges_observed = false; let mut first_token_gauges_observed = false;
while let Some(mut response) = stream.next().await { while let Some(mut response) = stream.next().await {
// Inject worker_id into first response if tracker is available
if first_item { if first_item {
first_item = false; first_item = false;
if let (Some(tracker), Some(data)) = (&tracker, &mut response.data) { if let (Some(tracker), Some(data)) = (&tracker, &mut response.data) {
...@@ -777,7 +718,6 @@ impl KvPushRouter { ...@@ -777,7 +718,6 @@ impl KvPushRouter {
} }
} }
// Observe per-worker TTFT/ISL gauges on first response with actual tokens
if !first_token_gauges_observed { if !first_token_gauges_observed {
let has_tokens = response let has_tokens = response
.data .data
...@@ -792,7 +732,6 @@ impl KvPushRouter { ...@@ -792,7 +732,6 @@ impl KvPushRouter {
} }
} }
// Convert LLMEngineOutput to PyObject
let py_response = Python::with_gil(|py| { let py_response = Python::with_gil(|py| {
pythonize(py, &response.data) pythonize(py, &response.data)
.map(|obj| obj.unbind()) .map(|obj| obj.unbind())
...@@ -801,8 +740,8 @@ impl KvPushRouter { ...@@ -801,8 +740,8 @@ impl KvPushRouter {
match py_response { match py_response {
Ok(obj) => { Ok(obj) => {
if tx.send(obj).await.is_err() { if tx.send(RsAnnotated::from_data(obj)).await.is_err() {
break; // Receiver dropped break;
} }
} }
Err(e) => { Err(e) => {
...@@ -812,23 +751,19 @@ impl KvPushRouter { ...@@ -812,23 +751,19 @@ impl KvPushRouter {
} }
} }
// Observe per-worker ITL gauge at stream end
if let Some(ref tracker) = tracker { if let Some(ref tracker) = tracker {
tracker.observe_finish_gauges(); tracker.observe_finish_gauges();
} }
}); });
// Return a Python async generator wrapper Ok(crate::AsyncResponseStream::new(rx, false))
Ok(KvPushRouterStream {
rx: Arc::new(tokio::sync::Mutex::new(rx)),
})
}) })
} }
} }
#[pymethods] #[pymethods]
impl KvPushRouter { impl KvRouter {
/// Create a new KvPushRouter for KV-aware routing to workers. /// Create a new KvRouter for KV-aware routing to workers.
/// ///
/// # Arguments /// # Arguments
/// * `endpoint` - The endpoint to route requests to /// * `endpoint` - The endpoint to route requests to
...@@ -869,8 +804,7 @@ impl KvPushRouter { ...@@ -869,8 +804,7 @@ impl KvPushRouter {
) )
.await?; .await?;
// Create KvPushRouter (kv_router is already Arc<KvRouter>) let kv_push_router = RsKvPushRouter::new(push_router, kv_router);
let kv_push_router = llm_rs::kv_router::KvPushRouter::new(push_router, kv_router);
Ok(Self { Ok(Self {
inner: Arc::new(kv_push_router), inner: Arc::new(kv_push_router),
...@@ -1078,32 +1012,3 @@ impl KvPushRouter { ...@@ -1078,32 +1012,3 @@ impl KvPushRouter {
}) })
} }
} }
// Python async generator wrapper for the stream
#[pyclass]
pub(crate) struct KvPushRouterStream {
rx: Arc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<PyObject>>>,
}
#[pymethods]
impl KvPushRouterStream {
#[pyo3(name = "__aiter__")]
fn aiter(slf: Bound<'_, Self>) -> PyResult<Py<PyAny>> {
Ok(slf.clone().into_any().unbind())
}
#[pyo3(name = "__anext__")]
fn anext<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let rx = self.rx.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let mut rx = rx.lock().await;
match rx.recv().await {
Some(obj) => Ok(obj),
None => Err(pyo3::exceptions::PyStopAsyncIteration::new_err(
"Stream exhausted",
)),
}
})
}
}
...@@ -265,7 +265,7 @@ class ModelCardInstanceId: ...@@ -265,7 +265,7 @@ class ModelCardInstanceId:
... ...
def compute_block_hash_for_seq_py( def compute_block_hash_for_seq(
tokens: List[int], tokens: List[int],
kv_block_size: int, kv_block_size: int,
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None
...@@ -300,7 +300,7 @@ def compute_block_hash_for_seq_py( ...@@ -300,7 +300,7 @@ def compute_block_hash_for_seq_py(
... "mm_hash": 0xDEADBEEF, ... "mm_hash": 0xDEADBEEF,
... }] ... }]
... } ... }
>>> hashes = compute_block_hash_for_seq_py(tokens, 32, [mm_info]) >>> hashes = compute_block_hash_for_seq(tokens, 32, [mm_info])
""" """
... ...
...@@ -694,24 +694,25 @@ class KvEventPublisher: ...@@ -694,24 +694,25 @@ class KvEventPublisher:
kv_block_size: int = 0, kv_block_size: int = 0,
dp_rank: int = 0, dp_rank: int = 0,
enable_local_indexer: bool = False, enable_local_indexer: bool = False,
zmq_config: Optional[ZmqKvEventPublisherConfig] = None, zmq_endpoint: Optional[str] = None,
zmq_topic: Optional[str] = None,
) -> None: ) -> None:
""" """
Create a `KvEventPublisher` object. Create a `KvEventPublisher` object.
When zmq_config is provided, the publisher subscribes to a ZMQ socket for When zmq_endpoint is provided, the publisher subscribes to a ZMQ socket for
incoming engine events (e.g. from SGLang/vLLM) and relays them to NATS. incoming engine events (e.g. from SGLang/vLLM) and relays them to NATS.
The zmq_config fields override kv_block_size, dp_rank, and enable_local_indexer.
When zmq_config is None, events are pushed manually via publish_stored/publish_removed. When zmq_endpoint is None, events are pushed manually via publish_stored/publish_removed.
Args: Args:
component: The component to publish events for component: The component to publish events for
worker_id: The worker ID (unused, inferred from component) worker_id: The worker ID (unused, inferred from component)
kv_block_size: The KV block size (must be > 0; ignored if zmq_config is set) kv_block_size: The KV block size (must be > 0)
dp_rank: The data parallel rank (defaults to 0; ignored if zmq_config is set) dp_rank: The data parallel rank (defaults to 0)
enable_local_indexer: Enable worker-local KV indexer (ignored if zmq_config is set) enable_local_indexer: Enable worker-local KV indexer
zmq_config: Optional ZMQ configuration for relay mode zmq_endpoint: Optional ZMQ endpoint for relay mode (e.g. "tcp://127.0.0.1:5557")
zmq_topic: ZMQ topic to subscribe to (defaults to "" when zmq_endpoint is set)
""" """
def publish_stored( def publish_stored(
...@@ -753,28 +754,6 @@ class KvEventPublisher: ...@@ -753,28 +754,6 @@ class KvEventPublisher:
""" """
... ...
class ZmqKvEventPublisherConfig:
def __init__(
self,
worker_id: int,
kv_block_size: int,
zmq_endpoint: str = "tcp://127.0.0.1:5557",
zmq_topic: str = "",
enable_local_indexer: bool = True,
dp_rank: int = 0
) -> None:
"""
ZMQ configuration for KvEventPublisher relay mode.
:param worker_id: The worker ID.
:param kv_block_size: The block size for the key-value store.
:param zmq_endpoint: The ZeroMQ endpoint. Defaults to "tcp://127.0.0.1:5557".
:param zmq_topic: The ZeroMQ topic to subscribe to. Defaults to an empty string.
:param enable_local_indexer: Whether to enable the worker-local KV indexer. Defaults to True.
:param dp_rank: The data parallel rank for this publisher. Defaults to 0.
"""
...
class HttpService: class HttpService:
""" """
A HTTP service for dynamo applications. A HTTP service for dynamo applications.
...@@ -1344,9 +1323,9 @@ class ZmqKvEventListener: ...@@ -1344,9 +1323,9 @@ class ZmqKvEventListener:
""" """
... ...
class KvPushRouter: class KvRouter:
""" """
A KV-aware push router that performs intelligent routing based on KV cache overlap. A KV-aware router that performs intelligent routing based on KV cache overlap.
""" """
def __init__( def __init__(
...@@ -1356,7 +1335,7 @@ class KvPushRouter: ...@@ -1356,7 +1335,7 @@ class KvPushRouter:
kv_router_config: KvRouterConfig, kv_router_config: KvRouterConfig,
) -> None: ) -> None:
""" """
Create a new KvPushRouter instance. Create a new KvRouter instance.
Args: Args:
endpoint: The endpoint to connect to for routing requests endpoint: The endpoint to connect to for routing requests
......
...@@ -11,7 +11,7 @@ from dynamo._core import HttpAsyncEngine as HttpAsyncEngine ...@@ -11,7 +11,7 @@ from dynamo._core import HttpAsyncEngine as HttpAsyncEngine
from dynamo._core import HttpService as HttpService from dynamo._core import HttpService as HttpService
from dynamo._core import KserveGrpcService as KserveGrpcService from dynamo._core import KserveGrpcService as KserveGrpcService
from dynamo._core import KvEventPublisher as KvEventPublisher from dynamo._core import KvEventPublisher as KvEventPublisher
from dynamo._core import KvPushRouter as KvPushRouter from dynamo._core import KvRouter as KvRouter
from dynamo._core import KvRouterConfig as KvRouterConfig from dynamo._core import KvRouterConfig as KvRouterConfig
from dynamo._core import LoRADownloader as LoRADownloader from dynamo._core import LoRADownloader as LoRADownloader
from dynamo._core import MediaDecoder as MediaDecoder from dynamo._core import MediaDecoder as MediaDecoder
...@@ -28,8 +28,7 @@ from dynamo._core import RouterConfig as RouterConfig ...@@ -28,8 +28,7 @@ from dynamo._core import RouterConfig as RouterConfig
from dynamo._core import RouterMode as RouterMode from dynamo._core import RouterMode as RouterMode
from dynamo._core import WorkerMetricsPublisher as WorkerMetricsPublisher from dynamo._core import WorkerMetricsPublisher as WorkerMetricsPublisher
from dynamo._core import ZmqKvEventListener as ZmqKvEventListener from dynamo._core import ZmqKvEventListener as ZmqKvEventListener
from dynamo._core import ZmqKvEventPublisherConfig as ZmqKvEventPublisherConfig from dynamo._core import compute_block_hash_for_seq as compute_block_hash_for_seq
from dynamo._core import compute_block_hash_for_seq_py as compute_block_hash_for_seq_py
from dynamo._core import fetch_llm as fetch_llm from dynamo._core import fetch_llm as fetch_llm
from dynamo._core import lora_name_to_id as lora_name_to_id from dynamo._core import lora_name_to_id as lora_name_to_id
from dynamo._core import make_engine from dynamo._core import make_engine
......
...@@ -23,7 +23,7 @@ from typing import Any ...@@ -23,7 +23,7 @@ from typing import Any
import pytest import pytest
from dynamo.llm import RadixTree, compute_block_hash_for_seq_py from dynamo.llm import RadixTree, compute_block_hash_for_seq
pytestmark = pytest.mark.pre_merge pytestmark = pytest.mark.pre_merge
...@@ -215,17 +215,17 @@ def test_mm_block_hash_computation_basic(): ...@@ -215,17 +215,17 @@ def test_mm_block_hash_computation_basic():
tokens = [100] * DEFAULT_BLOCK_SIZE tokens = [100] * DEFAULT_BLOCK_SIZE
# Without MM info # Without MM info
hashes_no_mm = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE) hashes_no_mm = compute_block_hash_for_seq(tokens, DEFAULT_BLOCK_SIZE)
assert len(hashes_no_mm) == 1 assert len(hashes_no_mm) == 1
# With MM info 1 # With MM info 1
hashes_mm1 = compute_block_hash_for_seq_py( hashes_mm1 = compute_block_hash_for_seq(
tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_1)] tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_1)]
) )
assert len(hashes_mm1) == 1 assert len(hashes_mm1) == 1
# With MM info 2 # With MM info 2
hashes_mm2 = compute_block_hash_for_seq_py( hashes_mm2 = compute_block_hash_for_seq(
tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_2)] tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_2)]
) )
assert len(hashes_mm2) == 1 assert len(hashes_mm2) == 1
...@@ -242,8 +242,8 @@ def test_mm_block_hash_determinism(): ...@@ -242,8 +242,8 @@ def test_mm_block_hash_determinism():
tokens = [100] * DEFAULT_BLOCK_SIZE tokens = [100] * DEFAULT_BLOCK_SIZE
mm_info = [make_mm_info(MM_HASH_1)] mm_info = [make_mm_info(MM_HASH_1)]
hash1 = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, mm_info) hash1 = compute_block_hash_for_seq(tokens, DEFAULT_BLOCK_SIZE, mm_info)
hash2 = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, mm_info) hash2 = compute_block_hash_for_seq(tokens, DEFAULT_BLOCK_SIZE, mm_info)
assert hash1 == hash2 assert hash1 == hash2
...@@ -261,7 +261,7 @@ def test_mm_block_hash_multiple_blocks(block_size: int): ...@@ -261,7 +261,7 @@ def test_mm_block_hash_multiple_blocks(block_size: int):
# One MM info per block # One MM info per block
mm_infos = [make_mm_info(MM_HASH_1) for _ in range(num_blocks)] mm_infos = [make_mm_info(MM_HASH_1) for _ in range(num_blocks)]
hashes = compute_block_hash_for_seq_py(tokens, block_size, mm_infos) hashes = compute_block_hash_for_seq(tokens, block_size, mm_infos)
assert len(hashes) == num_blocks assert len(hashes) == num_blocks
# Each block should have a unique hash (due to different tokens) # Each block should have a unique hash (due to different tokens)
...@@ -277,7 +277,7 @@ def test_mm_block_hash_partial_block(): ...@@ -277,7 +277,7 @@ def test_mm_block_hash_partial_block():
# MM info for each block # MM info for each block
mm_infos = [make_mm_info(MM_HASH_1), make_mm_info(MM_HASH_2)] mm_infos = [make_mm_info(MM_HASH_1), make_mm_info(MM_HASH_2)]
hashes = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, mm_infos) hashes = compute_block_hash_for_seq(tokens, DEFAULT_BLOCK_SIZE, mm_infos)
# Only complete blocks get hashes - partial blocks are not hashed # Only complete blocks get hashes - partial blocks are not hashed
assert len(hashes) == 1 assert len(hashes) == 1
...@@ -291,10 +291,8 @@ def test_mm_block_hash_none_mm_info(): ...@@ -291,10 +291,8 @@ def test_mm_block_hash_none_mm_info():
# Pass None for some blocks' MM info # Pass None for some blocks' MM info
mm_infos = [None] mm_infos = [None]
hashes_with_none = compute_block_hash_for_seq_py( hashes_with_none = compute_block_hash_for_seq(tokens, DEFAULT_BLOCK_SIZE, mm_infos)
tokens, DEFAULT_BLOCK_SIZE, mm_infos hashes_without = compute_block_hash_for_seq(tokens, DEFAULT_BLOCK_SIZE)
)
hashes_without = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE)
# Both should produce the same result # Both should produce the same result
assert hashes_with_none == hashes_without assert hashes_with_none == hashes_without
...@@ -309,8 +307,8 @@ def test_mm_block_hash_different_offsets(): ...@@ -309,8 +307,8 @@ def test_mm_block_hash_different_offsets():
mm_info_1 = make_mm_info(MM_HASH_1, offsets=[[0, 10]]) mm_info_1 = make_mm_info(MM_HASH_1, offsets=[[0, 10]])
mm_info_2 = make_mm_info(MM_HASH_1, offsets=[[5, 15]]) mm_info_2 = make_mm_info(MM_HASH_1, offsets=[[5, 15]])
hash1 = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, [mm_info_1]) hash1 = compute_block_hash_for_seq(tokens, DEFAULT_BLOCK_SIZE, [mm_info_1])
hash2 = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, [mm_info_2]) hash2 = compute_block_hash_for_seq(tokens, DEFAULT_BLOCK_SIZE, [mm_info_2])
# Currently offsets are not included in hash computation - just mm_hash # Currently offsets are not included in hash computation - just mm_hash
# This behavior may change - update test if needed # This behavior may change - update test if needed
...@@ -330,12 +328,12 @@ def test_mm_block_hash_multiple_mm_objects(): ...@@ -330,12 +328,12 @@ def test_mm_block_hash_multiple_mm_objects():
] ]
} }
hashes = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, [mm_info]) hashes = compute_block_hash_for_seq(tokens, DEFAULT_BLOCK_SIZE, [mm_info])
assert len(hashes) == 1 assert len(hashes) == 1
# Compare with single MM object # Compare with single MM object
single_mm_hashes = compute_block_hash_for_seq_py( single_mm_hashes = compute_block_hash_for_seq(
tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_1)] tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_1)]
) )
...@@ -349,7 +347,7 @@ def test_mm_block_hash_error_zero_block_size(): ...@@ -349,7 +347,7 @@ def test_mm_block_hash_error_zero_block_size():
tokens = [100] * 32 tokens = [100] * 32
with pytest.raises(ValueError, match="kv_block_size cannot be 0"): with pytest.raises(ValueError, match="kv_block_size cannot be 0"):
compute_block_hash_for_seq_py(tokens, 0) compute_block_hash_for_seq(tokens, 0)
# ============================================================================= # =============================================================================
...@@ -364,10 +362,10 @@ def test_integration_mm_hash_to_routing(): ...@@ -364,10 +362,10 @@ def test_integration_mm_hash_to_routing():
tokens = [100] * DEFAULT_BLOCK_SIZE tokens = [100] * DEFAULT_BLOCK_SIZE
# Compute hashes for two different MM contents # Compute hashes for two different MM contents
hash_mm1 = compute_block_hash_for_seq_py( hash_mm1 = compute_block_hash_for_seq(
tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_1)] tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_1)]
)[0] )[0]
hash_mm2 = compute_block_hash_for_seq_py( hash_mm2 = compute_block_hash_for_seq(
tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_2)] tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_2)]
)[0] )[0]
...@@ -406,7 +404,7 @@ def test_integration_multiple_workers_same_tokens(num_workers: int): ...@@ -406,7 +404,7 @@ def test_integration_multiple_workers_same_tokens(num_workers: int):
# Store blocks for each worker # Store blocks for each worker
for worker_id, mm_hash in enumerate(mm_hashes): for worker_id, mm_hash in enumerate(mm_hashes):
block_hash = compute_block_hash_for_seq_py( block_hash = compute_block_hash_for_seq(
tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(mm_hash)] tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(mm_hash)]
)[0] )[0]
...@@ -423,7 +421,7 @@ def test_integration_multiple_workers_same_tokens(num_workers: int): ...@@ -423,7 +421,7 @@ def test_integration_multiple_workers_same_tokens(num_workers: int):
# Query for each worker's block should match only that worker # Query for each worker's block should match only that worker
for worker_id, mm_hash in enumerate(mm_hashes): for worker_id, mm_hash in enumerate(mm_hashes):
block_hash = compute_block_hash_for_seq_py( block_hash = compute_block_hash_for_seq(
tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(mm_hash)] tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(mm_hash)]
)[0] )[0]
......
...@@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional
import aiohttp import aiohttp
import nats import nats
from dynamo._core import DistributedRuntime, KvPushRouter, KvRouterConfig from dynamo._core import DistributedRuntime, KvRouter, KvRouterConfig
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -188,7 +188,7 @@ async def wait_for_frontend_ready( ...@@ -188,7 +188,7 @@ async def wait_for_frontend_ready(
2. Sends a test POST to /v1/chat/completions to verify the request pipeline is functional 2. Sends a test POST to /v1/chat/completions to verify the request pipeline is functional
Use this when testing through the HTTP frontend server (dynamo.frontend). Use this when testing through the HTTP frontend server (dynamo.frontend).
For direct Python API testing with KvPushRouter, use wait_for_workers_ready() instead. For direct Python API testing with KvRouter, use wait_for_workers_ready() instead.
Args: Args:
frontend_url: Base URL of the frontend HTTP server (e.g., "http://localhost:8000") frontend_url: Base URL of the frontend HTTP server (e.g., "http://localhost:8000")
...@@ -276,7 +276,7 @@ async def wait_for_frontend_ready( ...@@ -276,7 +276,7 @@ async def wait_for_frontend_ready(
async def wait_for_workers_ready( async def wait_for_workers_ready(
endpoint, endpoint,
router: KvPushRouter, router: KvRouter,
expected_num_workers: int, expected_num_workers: int,
model_name: str, model_name: str,
) -> list[int]: ) -> list[int]:
...@@ -289,7 +289,7 @@ async def wait_for_workers_ready( ...@@ -289,7 +289,7 @@ async def wait_for_workers_ready(
Args: Args:
endpoint: The endpoint object to get the client from endpoint: The endpoint object to get the client from
router: The KvPushRouter to use for sending warmup requests router: The KvRouter to use for sending warmup requests
expected_num_workers: Number of workers to wait for expected_num_workers: Number of workers to wait for
Returns: Returns:
...@@ -493,7 +493,7 @@ async def send_inflight_requests(urls: list, payload: dict, num_requests: int): ...@@ -493,7 +493,7 @@ async def send_inflight_requests(urls: list, payload: dict, num_requests: int):
async def send_request_via_python_kv_router( async def send_request_via_python_kv_router(
kv_python_router: KvPushRouter, kv_python_router: KvRouter,
model_name: str, model_name: str,
token_ids: list, token_ids: list,
initial_wait: float, initial_wait: float,
...@@ -609,7 +609,7 @@ async def send_request_via_python_kv_router( ...@@ -609,7 +609,7 @@ async def send_request_via_python_kv_router(
) )
logger.debug( logger.debug(
f"Successfully verified {max_tokens} tokens generated as expected via KvPushRouter with ignore_eos=True" f"Successfully verified {max_tokens} tokens generated as expected via KvRouter with ignore_eos=True"
) )
if return_worker_ids: if return_worker_ids:
...@@ -883,9 +883,9 @@ def _test_python_router_bindings( ...@@ -883,9 +883,9 @@ def _test_python_router_bindings(
model_name: str, model_name: str,
num_workers: int, num_workers: int,
): ):
"""Test KvPushRouter Python bindings with token streaming and config overrides. """Test KvRouter Python bindings with token streaming and config overrides.
Assumes engine_workers are already initialized. This test creates a KvPushRouter Assumes engine_workers are already initialized. This test creates a KvRouter
Python object and sends three test requests to verify: Python object and sends three test requests to verify:
1. Token streaming with full router config overrides (overlap_score_weight, router_temperature) 1. Token streaming with full router config overrides (overlap_score_weight, router_temperature)
2. Token streaming without any overrides (uses default config) 2. Token streaming without any overrides (uses default config)
...@@ -906,19 +906,17 @@ def _test_python_router_bindings( ...@@ -906,19 +906,17 @@ def _test_python_router_bindings(
# Create KvRouterConfig with default settings # Create KvRouterConfig with default settings
kv_router_config = KvRouterConfig() kv_router_config = KvRouterConfig()
# Create KvPushRouter Python object # Create KvRouter Python object
kv_push_router = KvPushRouter( kv_router = KvRouter(
endpoint=endpoint, endpoint=endpoint,
block_size=block_size, block_size=block_size,
kv_router_config=kv_router_config, kv_router_config=kv_router_config,
) )
logger.info("Created KvPushRouter Python object") logger.info("Created KvRouter Python object")
# Wait for workers to be ready # Wait for workers to be ready
asyncio.run( asyncio.run(wait_for_workers_ready(endpoint, kv_router, num_workers, model_name))
wait_for_workers_ready(endpoint, kv_push_router, num_workers, model_name)
)
# Generate random token IDs (100 to 200 tokens) # Generate random token IDs (100 to 200 tokens)
num_input_tokens = random.randint(100, 200) num_input_tokens = random.randint(100, 200)
...@@ -936,7 +934,7 @@ def _test_python_router_bindings( ...@@ -936,7 +934,7 @@ def _test_python_router_bindings(
logger.info(f"Testing with full router config overrides: {router_config_override}") logger.info(f"Testing with full router config overrides: {router_config_override}")
asyncio.run( asyncio.run(
send_request_via_python_kv_router( send_request_via_python_kv_router(
kv_python_router=kv_push_router, kv_python_router=kv_router,
model_name=model_name, model_name=model_name,
token_ids=token_ids, token_ids=token_ids,
initial_wait=1.0, initial_wait=1.0,
...@@ -958,7 +956,7 @@ def _test_python_router_bindings( ...@@ -958,7 +956,7 @@ def _test_python_router_bindings(
logger.info("Testing without router config overrides") logger.info("Testing without router config overrides")
asyncio.run( asyncio.run(
send_request_via_python_kv_router( send_request_via_python_kv_router(
kv_python_router=kv_push_router, kv_python_router=kv_router,
model_name=model_name, model_name=model_name,
token_ids=token_ids[:50], # Use fewer tokens for second test, token_ids=token_ids[:50], # Use fewer tokens for second test,
initial_wait=1.0, initial_wait=1.0,
...@@ -981,7 +979,7 @@ def _test_python_router_bindings( ...@@ -981,7 +979,7 @@ def _test_python_router_bindings(
logger.info(f"Testing with partial router config overrides: {partial_override}") logger.info(f"Testing with partial router config overrides: {partial_override}")
asyncio.run( asyncio.run(
send_request_via_python_kv_router( send_request_via_python_kv_router(
kv_python_router=kv_push_router, kv_python_router=kv_router,
model_name=model_name, model_name=model_name,
token_ids=token_ids[:30], # Use fewer tokens for third test, token_ids=token_ids[:30], # Use fewer tokens for third test,
initial_wait=1.0, initial_wait=1.0,
...@@ -999,7 +997,7 @@ def _test_python_router_bindings( ...@@ -999,7 +997,7 @@ def _test_python_router_bindings(
) )
) )
logger.info("KvPushRouter bindings test completed successfully") logger.info("KvRouter bindings test completed successfully")
def _test_router_query_instance_id( def _test_router_query_instance_id(
...@@ -1310,8 +1308,8 @@ def _test_router_indexers_sync( ...@@ -1310,8 +1308,8 @@ def _test_router_indexers_sync(
"""Test that two KV routers have synchronized indexer states after processing requests. """Test that two KV routers have synchronized indexer states after processing requests.
Assumes engine_workers are already initialized. This test: Assumes engine_workers are already initialized. This test:
1. Creates first KvPushRouter (with its own runtime) and sends 25 requests (triggers snapshot at threshold=20) 1. Creates first KvRouter (with its own runtime) and sends 25 requests (triggers snapshot at threshold=20)
2. Creates second KvPushRouter (with its own runtime, should sync from NATS snapshot) 2. Creates second KvRouter (with its own runtime, should sync from NATS snapshot)
3. Sends 25 requests to second router 3. Sends 25 requests to second router
4. Verifies NATS object store contains the snapshot 4. Verifies NATS object store contains the snapshot
5. Dumps states from both routers and compares them (should be identical) 5. Dumps states from both routers and compares them (should be identical)
...@@ -1395,23 +1393,21 @@ def _test_router_indexers_sync( ...@@ -1395,23 +1393,21 @@ def _test_router_indexers_sync(
component1 = namespace1.component(engine_workers.component_name) component1 = namespace1.component(engine_workers.component_name)
endpoint1 = component1.endpoint("generate") endpoint1 = component1.endpoint("generate")
kv_push_router1 = KvPushRouter( kv_router1 = KvRouter(
endpoint=endpoint1, endpoint=endpoint1,
block_size=block_size, block_size=block_size,
kv_router_config=kv_router_config, kv_router_config=kv_router_config,
) )
# Wait for workers to be ready # Wait for workers to be ready
await wait_for_workers_ready( await wait_for_workers_ready(endpoint1, kv_router1, num_workers, model_name)
endpoint1, kv_push_router1, num_workers, model_name
)
# Send 25 requests to first router # Send 25 requests to first router
logger.info("Sending 25 requests to first router") logger.info("Sending 25 requests to first router")
# Send requests to first router # Send requests to first router
successful1 = await send_requests_to_router( successful1 = await send_requests_to_router(
kv_push_router1, 25, "Router 1", endpoint1 kv_router1, 25, "Router 1", endpoint1
) )
assert ( assert (
successful1 == 25 successful1 == 25
...@@ -1428,7 +1424,7 @@ def _test_router_indexers_sync( ...@@ -1428,7 +1424,7 @@ def _test_router_indexers_sync(
logger.info("Sending 10 requests while NATS is down (via TCP)") logger.info("Sending 10 requests while NATS is down (via TCP)")
successful_offline1 = await send_requests_to_router( successful_offline1 = await send_requests_to_router(
kv_push_router1, 10, "Router 1 (NATS down)", endpoint1 kv_router1, 10, "Router 1 (NATS down)", endpoint1
) )
assert ( assert (
successful_offline1 == 10 successful_offline1 == 10
...@@ -1450,7 +1446,7 @@ def _test_router_indexers_sync( ...@@ -1450,7 +1446,7 @@ def _test_router_indexers_sync(
component2 = namespace2.component(engine_workers.component_name) component2 = namespace2.component(engine_workers.component_name)
endpoint2 = component2.endpoint("generate") endpoint2 = component2.endpoint("generate")
kv_push_router2 = KvPushRouter( kv_router2 = KvRouter(
endpoint=endpoint2, endpoint=endpoint2,
block_size=block_size, block_size=block_size,
kv_router_config=kv_router_config, kv_router_config=kv_router_config,
...@@ -1459,7 +1455,7 @@ def _test_router_indexers_sync( ...@@ -1459,7 +1455,7 @@ def _test_router_indexers_sync(
# Send 25 requests to second router with initial retry loop # Send 25 requests to second router with initial retry loop
logger.info("Sending 25 requests to second router") logger.info("Sending 25 requests to second router")
successful2 = await send_requests_to_router( successful2 = await send_requests_to_router(
kv_push_router2, 25, "Router 2", endpoint2 kv_router2, 25, "Router 2", endpoint2
) )
assert ( assert (
successful2 == 25 successful2 == 25
...@@ -1476,7 +1472,7 @@ def _test_router_indexers_sync( ...@@ -1476,7 +1472,7 @@ def _test_router_indexers_sync(
logger.info("Sending 10 requests while NATS is down (via TCP)") logger.info("Sending 10 requests while NATS is down (via TCP)")
successful_offline2 = await send_requests_to_router( successful_offline2 = await send_requests_to_router(
kv_push_router2, 10, "Router 2 (NATS down)", endpoint2 kv_router2, 10, "Router 2 (NATS down)", endpoint2
) )
assert ( assert (
successful_offline2 == 10 successful_offline2 == 10
...@@ -1488,7 +1484,7 @@ def _test_router_indexers_sync( ...@@ -1488,7 +1484,7 @@ def _test_router_indexers_sync(
logger.info("Sending 5 more requests after NATS recovery") logger.info("Sending 5 more requests after NATS recovery")
successful_recovery = await send_requests_to_router( successful_recovery = await send_requests_to_router(
kv_push_router1, 5, "Router 1 (post-recovery)", endpoint1 kv_router1, 5, "Router 1 (post-recovery)", endpoint1
) )
assert ( assert (
successful_recovery == 5 successful_recovery == 5
...@@ -1551,8 +1547,8 @@ def _test_router_indexers_sync( ...@@ -1551,8 +1547,8 @@ def _test_router_indexers_sync(
# Dump states from both routers # Dump states from both routers
logger.info("Dumping states from both routers") logger.info("Dumping states from both routers")
state1_json = await kv_push_router1.dump_events() state1_json = await kv_router1.dump_events()
state2_json = await kv_push_router2.dump_events() state2_json = await kv_router2.dump_events()
# Parse JSON strings for comparison # Parse JSON strings for comparison
state1 = json.loads(state1_json) state1 = json.loads(state1_json)
...@@ -1916,7 +1912,7 @@ def _test_router_decisions( ...@@ -1916,7 +1912,7 @@ def _test_router_decisions(
durable_kv_events=durable_kv_events, durable_kv_events=durable_kv_events,
router_event_threads=router_event_threads, router_event_threads=router_event_threads,
) )
kv_push_router = KvPushRouter( kv_router = KvRouter(
endpoint=endpoint, endpoint=endpoint,
block_size=block_size, block_size=block_size,
kv_router_config=kv_router_config, kv_router_config=kv_router_config,
...@@ -1930,7 +1926,7 @@ def _test_router_decisions( ...@@ -1930,7 +1926,7 @@ def _test_router_decisions(
# Wait for workers to be ready and get their instance IDs # Wait for workers to be ready and get their instance IDs
worker_ids = await wait_for_workers_ready( worker_ids = await wait_for_workers_ready(
endpoint, endpoint,
kv_push_router, kv_router,
expected_num_workers=expected_num_instances, expected_num_workers=expected_num_instances,
model_name=model_name, model_name=model_name,
) )
...@@ -1976,7 +1972,7 @@ def _test_router_decisions( ...@@ -1976,7 +1972,7 @@ def _test_router_decisions(
logger.info(log_msg) logger.info(log_msg)
result = await send_request_via_python_kv_router( result = await send_request_via_python_kv_router(
kv_python_router=kv_push_router, kv_python_router=kv_router,
model_name=model_name, model_name=model_name,
token_ids=request, token_ids=request,
initial_wait=1.0, initial_wait=1.0,
...@@ -2004,7 +2000,7 @@ def _test_router_decisions( ...@@ -2004,7 +2000,7 @@ def _test_router_decisions(
await asyncio.sleep(1) await asyncio.sleep(1)
# Dump events from the router # Dump events from the router
events_json = await kv_push_router.dump_events() events_json = await kv_router.dump_events()
return events_json, forced_worker_id, forced_dp_rank, response_worker_ids return events_json, forced_worker_id, forced_dp_rank, response_worker_ids
# Run the async test # Run the async test
......
...@@ -479,15 +479,15 @@ def test_mocker_kv_router_overload_503( ...@@ -479,15 +479,15 @@ def test_mocker_kv_router_overload_503(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"durable_kv_events", [False], indirect=True "durable_kv_events", [False], indirect=True
) # Use NATS Core (local indexer) ) # Use NATS Core (local indexer)
def test_kv_push_router_bindings( def test_kv_router_bindings(
request, request,
runtime_services_dynamic_ports, runtime_services_dynamic_ports,
predownload_tokenizers, predownload_tokenizers,
request_plane, request_plane,
durable_kv_events, durable_kv_events,
): ):
"""Test KvPushRouter Python bindings with mocker engines.""" """Test KvRouter Python bindings with mocker engines."""
logger.info("Starting KvPushRouter bindings test") logger.info("Starting KvRouter bindings test")
# Use local indexer (NATS Core mode) # Use local indexer (NATS Core mode)
mocker_args = { mocker_args = {
"speedup_ratio": SPEEDUP_RATIO, "speedup_ratio": SPEEDUP_RATIO,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment