Unverified Commit 9b9536d0 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: make prefill router general (#3329)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 3d2d7e47
...@@ -96,14 +96,12 @@ In a **new terminal**, launch the Dynamo router using the Python CLI: ...@@ -96,14 +96,12 @@ In a **new terminal**, launch the Dynamo router using the Python CLI:
```bash ```bash
python -m dynamo.frontend \ python -m dynamo.frontend \
--router-mode kv \ --router-mode kv \
--kv-cache-block-size 64 \
--router-reset-states \ --router-reset-states \
--http-port 8000 --http-port 8000
``` ```
This starts the router with: This starts the router with:
- KV cache routing mode - KV cache routing mode
- Block size of 64 (**Important:** This should match the `--block-size` used by your engines)
- `--router-reset-states` flag to clear the event cache (JetStream) from previous runs (useful for single router benchmarking) - `--router-reset-states` flag to clear the event cache (JetStream) from previous runs (useful for single router benchmarking)
- HTTP port 8000 - HTTP port 8000
...@@ -114,33 +112,33 @@ python -m dynamo.frontend --help ...@@ -114,33 +112,33 @@ python -m dynamo.frontend --help
For detailed explanations of router arguments (especially KV cache routing parameters), see the [KV Cache Routing documentation](../../docs/architecture/kv_cache_routing.md). For detailed explanations of router arguments (especially KV cache routing parameters), see the [KV Cache Routing documentation](../../docs/architecture/kv_cache_routing.md).
#### Launching a Prefill Router (Optional) #### Launching a Standalone Router for Prefill Workers (Optional)
If you're using disaggregated serving with separate prefill and decode workers, you should also launch a prefill router. The prefill router handles routing prefill requests to dedicated prefill workers. When using a prefill router, it's recommended to start the frontend (decode router) with `--kv-overlap-score-weight 0` for pure load balancing (as prefix-aware routing is now handled by the prefill router): If you're using disaggregated serving with separate prefill and decode workers, you should also launch a standalone router for prefill workers. This router handles routing prefill requests to dedicated prefill workers. When using a standalone prefill router, it's recommended to start the frontend (decode router) with `--kv-overlap-score-weight 0` for pure load balancing (as prefix-aware routing is now handled by the standalone router):
```bash ```bash
# Start the decode router with pure load balancing # Start the decode router with pure load balancing
python -m dynamo.frontend \ python -m dynamo.frontend \
--router-mode kv \ --router-mode kv \
--kv-cache-block-size 64 \
--router-reset-states \ --router-reset-states \
--http-port 8000 \ --http-port 8000 \
--kv-overlap-score-weight 0 --kv-overlap-score-weight 0
# In another terminal, start the prefill router (currently only supports vLLM) # In another terminal, start the standalone router for prefill workers
python -m dynamo.vllm_prefill_router \ python -m dynamo.router \
--namespace dynamo \ --endpoint dynamo.prefill.generate \
--block-size 64 --block-size 64 \
--router-reset-states \
--no-track-active-blocks
``` ```
The prefill router will automatically coordinate with the decode router to handle request routing between prefill and decode workers. The `--router-reset-states` flag clears any previous state, and `--no-track-active-blocks` disables active block tracking (suitable for prefill-only routing where decode load is not relevant).
**Note**: If you're unsure whether your backend engines correctly emit KV events for certain models (e.g., hybrid models like gpt-oss or nemotron nano 2), use the `--no-kv-events` flag to disable KV event tracking and use approximate KV indexing instead: **Note**: If you're unsure whether your backend engines correctly emit KV events for certain models (e.g., hybrid models like gpt-oss or nemotron nano 2), use the `--no-kv-events` flag to disable KV event tracking and use approximate KV indexing instead:
```bash ```bash
python -m dynamo.frontend \ python -m dynamo.frontend \
--router-mode kv \ --router-mode kv \
--kv-cache-block-size 64 \
--http-port 8000 \ --http-port 8000 \
--no-kv-events --no-kv-events
``` ```
......
...@@ -18,10 +18,12 @@ python -m dynamo.frontend \ ...@@ -18,10 +18,12 @@ python -m dynamo.frontend \
--kv-overlap-score-weight 0 \ --kv-overlap-score-weight 0 \
--router-reset-states & --router-reset-states &
# run prefill router service # run standalone router service for prefill workers
python -m dynamo.vllm_prefill_router \ python -m dynamo.router \
--namespace dynamo \ --endpoint dynamo.prefill.generate \
--block-size $BLOCK_SIZE & --block-size $BLOCK_SIZE \
--router-reset-states \
--no-track-active-blocks &
# two decode workers # two decode workers
# --enforce-eager is added for quick deployment. for production use, need to remove this flag # --enforce-eager is added for quick deployment. for production use, need to remove this flag
......
...@@ -178,6 +178,13 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -178,6 +178,13 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
runtime, component, engine_client, default_sampling_params runtime, component, engine_client, default_sampling_params
) )
# Set up KV event publisher for prefix caching if enabled
kv_publisher = setup_kv_event_publisher(
config, component, generate_endpoint, vllm_config
)
if kv_publisher:
handler.kv_publisher = kv_publisher
health_check_payload = VllmPrefillHealthCheckPayload(engine_client).to_dict() health_check_payload = VllmPrefillHealthCheckPayload(engine_client).to_dict()
try: try:
...@@ -219,14 +226,14 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -219,14 +226,14 @@ async def init(runtime: DistributedRuntime, config: Config):
prefill_router_client = ( prefill_router_client = (
await runtime.namespace(config.namespace) await runtime.namespace(config.namespace)
.component("prefill_router") # TODO don't hardcode .component("router") # Standalone router for prefill workers
.endpoint("find_best_worker") .endpoint("find_best_worker")
.client() .client()
) )
prefill_router_free_client = ( prefill_router_free_client = (
await runtime.namespace(config.namespace) await runtime.namespace(config.namespace)
.component("prefill_router") # TODO don't hardcode .component("router") # Standalone router for prefill workers
.endpoint("free") .endpoint("free")
.client() .client()
) )
......
...@@ -167,6 +167,13 @@ def parse_args(): ...@@ -167,6 +167,13 @@ def parse_args():
default=False, default=False,
help="KV Router: Reset router state on startup, purging stream and object store. By default, states are persisted. WARNING: This can affect existing router replicas.", help="KV Router: Reset router state on startup, purging stream and object store. By default, states are persisted. WARNING: This can affect existing router replicas.",
) )
parser.add_argument(
"--no-track-active-blocks",
action="store_false",
dest="router_track_active_blocks",
default=True,
help="KV Router: Disable tracking of active blocks (blocks being used for ongoing generation). By default, active blocks are tracked for load balancing.",
)
parser.add_argument( parser.add_argument(
"--busy-threshold", "--busy-threshold",
type=float, type=float,
...@@ -232,6 +239,7 @@ async def async_main(): ...@@ -232,6 +239,7 @@ async def async_main():
router_replica_sync=flags.router_replica_sync, router_replica_sync=flags.router_replica_sync,
router_snapshot_threshold=flags.router_snapshot_threshold, router_snapshot_threshold=flags.router_snapshot_threshold,
router_reset_states=flags.router_reset_states, router_reset_states=flags.router_reset_states,
router_track_active_blocks=flags.router_track_active_blocks,
) )
elif flags.router_mode == "random": elif flags.router_mode == "random":
router_mode = RouterMode.Random router_mode = RouterMode.Random
......
<!-- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 -->
# Standalone Router
A backend-agnostic standalone KV-aware router service for Dynamo deployments. For details on how KV-aware routing works, see the [KV Cache Routing documentation](../../docs/architecture/kv_cache_routing.md).
## Overview
The standalone router provides configurable KV-aware routing for any set of workers in a Dynamo deployment. It can be used for disaggregated serving (e.g., routing to prefill workers), multi-tier architectures, or any scenario requiring intelligent KV cache-aware routing decisions.
This component is **fully configurable** and works with any Dynamo backend (vLLM, TensorRT-LLM, SGLang, etc.) and any worker endpoint.
## Usage
### Command Line
```bash
python -m dynamo.router \
--endpoint dynamo.prefill.generate \
--block-size 64 \
--router-reset-states \
--no-track-active-blocks
```
### Arguments
**Required:**
- `--endpoint`: Full endpoint path for workers in the format `namespace.component.endpoint` (e.g., `dynamo.prefill.generate`)
**Router Configuration:**
For detailed descriptions of all KV router configuration options including `--block-size`, `--kv-overlap-score-weight`, `--router-temperature`, `--no-kv-events`, `--router-replica-sync`, `--router-snapshot-threshold`, `--router-reset-states`, and `--no-track-active-blocks`, see the [KV Cache Routing documentation](../../docs/architecture/kv_cache_routing.md).
## Architecture
The standalone router exposes two endpoints via the Dynamo runtime:
1. **`find_best_worker`**: Given a request with token IDs, returns the best worker to handle it
2. **`free`**: Cleans up router state when a request completes
Clients query the `find_best_worker` endpoint to determine which worker should process each request, then call the selected worker directly.
## Example: Disaggregated Serving with Prefill Workers
See [`components/backends/vllm/launch/disagg_router.sh`](../backends/vllm/launch/disagg_router.sh) for a complete example.
```bash
# Start frontend router for decode workers
python -m dynamo.frontend \
--router-mode kv \
--http-port 8000 \
--kv-overlap-score-weight 0 # Pure load balancing for decode
# Start standalone router for prefill workers
python -m dynamo.router \
--endpoint dynamo.prefill.generate \
--block-size 64 \
--router-reset-states \
--no-track-active-blocks
# Start decode workers
python -m dynamo.vllm --model MODEL_NAME --block-size 64 &
# Start prefill workers
python -m dynamo.vllm --model MODEL_NAME --block-size 64 --is-prefill-worker &
```
>[!Note]
> **Why `--no-track-active-blocks` for prefill routing?**
> Active block tracking is used for load balancing across decode (generation) phases. For prefill-only routing, decode load is not relevant, so disabling this reduces overhead and simplifies the router state.
>
> **Why `--block-size` is required for standalone routers:**
> Unlike the frontend router which can infer block size from the ModelDeploymentCard (MDC) during worker registration, standalone routers cannot access the MDC and must have the block size explicitly specified. This is a work in progress to enable automatic inference.
## Configuration Best Practices
>[!Note]
> **Block Size Matching:**
> The block size must match across:
> - Standalone router (`--block-size`)
> - All worker instances (`--block-size`)
>
> **Endpoint Matching:**
> The `--endpoint` argument must match where your target workers register. For example:
> - vLLM prefill workers: `dynamo.prefill.generate`
> - vLLM decode workers: `dynamo.backend.generate`
> - Custom workers: `<your_namespace>.<your_component>.<your_endpoint>`
## Integration with Backends
To integrate the standalone router with a backend:
1. Clients should query the `router.find_best_worker` endpoint before sending requests
2. Workers should register at the endpoint specified by the `--endpoint` argument
3. Clients should call the `router.free` endpoint when requests complete
See [`components/backends/vllm/src/dynamo/vllm/handlers.py`](../backends/vllm/src/dynamo/vllm/handlers.py) for a reference implementation (search for `prefill_router_client`).
## See Also
- [KV Cache Routing Architecture](../../docs/architecture/kv_cache_routing.md) - Detailed explanation of KV-aware routing
- [Frontend Router](../frontend/README.md) - Main HTTP frontend with integrated routing
- [Router Benchmarking](../../benchmarks/router/README.md) - Performance testing and tuning
...@@ -2,20 +2,19 @@ ...@@ -2,20 +2,19 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
""" """
Centralized Prefill Router Service Standalone KV Router Service
Usage: python -m dynamo.vllm_prefill_router [args] Usage: python -m dynamo.router --endpoint <namespace.component.endpoint> [args]
This service provides a single KV-aware router for all prefill workers in a This service provides a standalone KV-aware router for any set of workers
disaggregated vLLM deployment. Instead of each decode worker maintaining its own in a Dynamo deployment. It can be used for disaggregated serving (e.g., routing
round-robin client to prefill workers, this service uses KvRouter to make to prefill workers) or any other scenario requiring intelligent KV cache-aware
intelligent routing decisions based on KV cache state. routing decisions.
""" """
import argparse import argparse
import asyncio import asyncio
import logging import logging
import os
from typing import Optional from typing import Optional
import uvloop import uvloop
...@@ -28,42 +27,49 @@ configure_dynamo_logging() ...@@ -28,42 +27,49 @@ configure_dynamo_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PrefillRouterHandler: class StandaloneRouterHandler:
"""Handles routing requests to prefill workers using KV-aware routing.""" """Handles routing requests to workers using KV-aware routing."""
def __init__(self, runtime: DistributedRuntime, namespace: str, block_size: int): def __init__(
self,
runtime: DistributedRuntime,
worker_endpoint_path: str,
block_size: int,
kv_router_config: KvRouterConfig,
):
self.runtime = runtime self.runtime = runtime
self.namespace = namespace self.worker_endpoint_path = worker_endpoint_path
self.block_size = block_size self.block_size = block_size
self.kv_router_config = kv_router_config
self.kv_router: Optional[KvRouter] = None self.kv_router: Optional[KvRouter] = None
self.prefill_client: Optional[Client] = None self.worker_client: Optional[Client] = None
async def initialize(self): async def initialize(self):
"""Initialize the KV router for prefill workers.""" """Initialize the KV router for workers."""
try: try:
# Get prefill endpoint # Parse endpoint path (format: namespace.component.endpoint)
prefill_endpoint = ( parts = self.worker_endpoint_path.split(".")
self.runtime.namespace(self.namespace) if len(parts) != 3:
.component("prefill") raise ValueError(
.endpoint("generate") f"Invalid endpoint path format: {self.worker_endpoint_path}. "
"Expected format: namespace.component.endpoint"
)
namespace, component, endpoint = parts
# Get worker endpoint
worker_endpoint = (
self.runtime.namespace(namespace)
.component(component)
.endpoint(endpoint)
) )
self.prefill_client = await prefill_endpoint.client() self.worker_client = await worker_endpoint.client()
# Create KvRouter with specified configuration # Create KvRouter with specified configuration
kv_router_config = KvRouterConfig(
router_track_active_blocks=False, # this won't matter for prefill workers
router_reset_states=True, # reset for now
)
self.kv_router = KvRouter( self.kv_router = KvRouter(
endpoint=prefill_endpoint, endpoint=worker_endpoint,
block_size=self.block_size, block_size=self.block_size,
kv_router_config=kv_router_config, kv_router_config=self.kv_router_config,
)
logger.info(
f"KvRouter initialized for prefill workers with block_size={self.block_size}"
) )
except Exception as e: except Exception as e:
...@@ -72,10 +78,10 @@ class PrefillRouterHandler: ...@@ -72,10 +78,10 @@ class PrefillRouterHandler:
async def find_best_worker(self, request): async def find_best_worker(self, request):
""" """
Find the best prefill worker based on KV cache state. Find the best worker based on KV cache state.
This endpoint is called by decode workers to determine which prefill This endpoint is called by clients to determine which worker
worker should handle a request. should handle a request.
""" """
if self.kv_router is None: if self.kv_router is None:
# Fallback to round-robin if router not initialized # Fallback to round-robin if router not initialized
...@@ -87,19 +93,19 @@ class PrefillRouterHandler: ...@@ -87,19 +93,19 @@ class PrefillRouterHandler:
return return
try: try:
# Get current prefill workers # Get current workers
if self.prefill_client is None: if self.worker_client is None:
yield { yield {
"status": "error", "status": "error",
"message": "Prefill client not initialized", "message": "Worker client not initialized",
} }
return return
instance_ids = self.prefill_client.instance_ids() instance_ids = self.worker_client.instance_ids()
if not instance_ids: if not instance_ids:
yield { yield {
"status": "error", "status": "error",
"message": "No prefill workers available", "message": "No workers available",
} }
return return
...@@ -118,7 +124,7 @@ class PrefillRouterHandler: ...@@ -118,7 +124,7 @@ class PrefillRouterHandler:
best_worker_id, overlap_blocks = await self.kv_router.find_best_match( best_worker_id, overlap_blocks = await self.kv_router.find_best_match(
request_id=request_id, request_id=request_id,
tokens=token_ids, tokens=token_ids,
update_states=True, # Always update states for prefill routing update_states=True, # Always update states for routing
) )
logger.debug( logger.debug(
...@@ -178,15 +184,18 @@ class PrefillRouterHandler: ...@@ -178,15 +184,18 @@ class PrefillRouterHandler:
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Dynamo Prefill Router Service: Centralized KV-aware routing for prefill workers", description="Dynamo Standalone Router Service: Configurable KV-aware routing for any worker endpoint",
formatter_class=argparse.RawTextHelpFormatter, formatter_class=argparse.RawTextHelpFormatter,
) )
parser.add_argument( parser.add_argument(
"--namespace", "--endpoint",
type=str, type=str,
default=os.environ.get("DYN_NAMESPACE", "dynamo"), required=True,
help="Dynamo namespace for discovering prefill workers (default: dynamo or DYN_NAMESPACE env var)", help=(
"Full endpoint path for workers in the format namespace.component.endpoint\n"
"(e.g., dynamo.prefill.generate for prefill workers)"
),
) )
parser.add_argument( parser.add_argument(
...@@ -197,11 +206,55 @@ def parse_args(): ...@@ -197,11 +206,55 @@ def parse_args():
) )
parser.add_argument( parser.add_argument(
"--log-level", "--kv-overlap-score-weight",
type=str, type=float,
default="INFO", default=1.0,
choices=["DEBUG", "INFO", "WARNING", "ERROR"], help="KV Router: Weight for overlap score in worker selection. Higher values prioritize KV cache reuse (default: 1.0)",
help="Logging level (default: INFO)", )
parser.add_argument(
"--router-temperature",
type=float,
default=0.0,
help="KV Router: Temperature for worker sampling via softmax. Higher values promote more randomness, and 0 fallbacks to deterministic (default: 0.0)",
)
parser.add_argument(
"--no-kv-events",
action="store_false",
dest="use_kv_events",
default=True,
help="KV Router: Disable KV events. When set, uses ApproxKvRouter for predicting block creation/deletion based only on incoming requests. By default, KV events are enabled.",
)
parser.add_argument(
"--router-replica-sync",
action="store_true",
default=False,
help="KV Router: Enable replica synchronization across multiple router instances. When true, routers will publish and subscribe to events to maintain consistent state (default: False)",
)
parser.add_argument(
"--router-snapshot-threshold",
type=int,
default=1000000,
help="KV Router: Number of messages in stream before triggering a snapshot (default: 1000000)",
)
parser.add_argument(
"--router-reset-states",
action="store_true",
dest="router_reset_states",
default=False,
help="KV Router: Reset router state on startup, purging stream and object store. By default, states are persisted. WARNING: This can affect existing router replicas (default: False)",
)
parser.add_argument(
"--no-track-active-blocks",
action="store_false",
dest="router_track_active_blocks",
default=True,
help="KV Router: Disable tracking of active blocks (blocks being used for ongoing generation). By default, active blocks are tracked for load balancing (default: True)",
) )
return parser.parse_args() return parser.parse_args()
...@@ -209,22 +262,49 @@ def parse_args(): ...@@ -209,22 +262,49 @@ def parse_args():
@dynamo_worker(static=False) @dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime): async def worker(runtime: DistributedRuntime):
"""Main worker function for the prefill router service.""" """Main worker function for the standalone router service."""
args = parse_args() args = parse_args()
# Set logging level # Parse endpoint path to get namespace for service registration
logging.getLogger().setLevel(getattr(logging, args.log_level)) endpoint_parts = args.endpoint.split(".")
if len(endpoint_parts) != 3:
raise ValueError(
f"Invalid endpoint path format: {args.endpoint}. "
"Expected format: namespace.component.endpoint"
)
namespace = endpoint_parts[0]
logger.info("Starting Standalone Router Service")
logger.debug(
f"Configuration: endpoint={args.endpoint}, block_size={args.block_size}, "
f"overlap_score_weight={args.kv_overlap_score_weight}, "
f"router_temperature={args.router_temperature}, "
f"use_kv_events={args.use_kv_events}, "
f"router_replica_sync={args.router_replica_sync}, "
f"router_reset_states={args.router_reset_states}, "
f"router_track_active_blocks={args.router_track_active_blocks}"
)
logger.info(f"Starting Prefill Router Service for namespace: {args.namespace}") # Create KvRouter configuration
logger.debug(f"Configuration: block_size={args.block_size}") kv_router_config = KvRouterConfig(
overlap_score_weight=args.kv_overlap_score_weight,
router_temperature=args.router_temperature,
use_kv_events=args.use_kv_events,
router_replica_sync=args.router_replica_sync,
router_snapshot_threshold=args.router_snapshot_threshold,
router_reset_states=args.router_reset_states,
router_track_active_blocks=args.router_track_active_blocks,
)
# Create service component # Create service component - use "router" as component name
component = runtime.namespace(args.namespace).component("prefill_router") component = runtime.namespace(namespace).component("router")
await component.create_service() await component.create_service()
# Create handler # Create handler
handler = PrefillRouterHandler(runtime, args.namespace, args.block_size) handler = StandaloneRouterHandler(
runtime, args.endpoint, args.block_size, kv_router_config
)
await handler.initialize() await handler.initialize()
# Expose endpoints # Expose endpoints
...@@ -238,23 +318,23 @@ async def worker(runtime: DistributedRuntime): ...@@ -238,23 +318,23 @@ async def worker(runtime: DistributedRuntime):
find_best_worker_endpoint.serve_endpoint( find_best_worker_endpoint.serve_endpoint(
handler.find_best_worker, handler.find_best_worker,
graceful_shutdown=True, graceful_shutdown=True,
metrics_labels=[("service", "prefill_router")], metrics_labels=[("service", "router")],
), ),
free_endpoint.serve_endpoint( free_endpoint.serve_endpoint(
handler.free, handler.free,
graceful_shutdown=True, graceful_shutdown=True,
metrics_labels=[("service", "prefill_router")], metrics_labels=[("service", "router")],
), ),
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to serve endpoint: {e}") logger.error(f"Failed to serve endpoint: {e}")
raise raise
finally: finally:
logger.info("Prefill Router Service shutting down") logger.info("Standalone Router Service shutting down")
def main(): def main():
"""Entry point for the prefill router service.""" """Entry point for the standalone router service."""
uvloop.run(worker()) uvloop.run(worker())
......
...@@ -29,6 +29,8 @@ The main KV-aware routing arguments: ...@@ -29,6 +29,8 @@ The main KV-aware routing arguments:
- `--router-snapshot-threshold`: Sets the number of messages in the JetStream before triggering a snapshot. When the message count exceeds this threshold, a router will attempt to purge acknowledged messages from the stream and create a snapshot of the current radix tree state in NATs object store. Defaults to 1000000. This helps manage stream size and provides faster initialization for routers that restart. - `--router-snapshot-threshold`: Sets the number of messages in the JetStream before triggering a snapshot. When the message count exceeds this threshold, a router will attempt to purge acknowledged messages from the stream and create a snapshot of the current radix tree state in NATs object store. Defaults to 1000000. This helps manage stream size and provides faster initialization for routers that restart.
- `--no-track-active-blocks`: Disables tracking of active blocks (blocks being used for ongoing generation/decode phases). By default, the router tracks active blocks for load balancing. Disable this when routing to workers that only perform prefill (no decode phase), as tracking decode load is not relevant. This reduces router overhead and simplifies state management.
>[!Note] >[!Note]
> State persistence is only available when KV events are enabled (default). When using `--no-kv-events` with `ApproxKvIndexer`, state persistence is not currently supported. > State persistence is only available when KV events are enabled (default). When using `--no-kv-events` with `ApproxKvIndexer`, state persistence is not currently supported.
> >
......
...@@ -834,6 +834,38 @@ impl SpecDecodeStats { ...@@ -834,6 +834,38 @@ impl SpecDecodeStats {
} }
} }
/// Helper function to create a KV router from an endpoint using the ModelManager
/// to ensure proper etcd registration
async fn create_kv_router_from_endpoint(
endpoint: &Endpoint,
block_size: usize,
kv_router_config: Option<llm_rs::kv_router::KvRouterConfig>,
) -> Result<Arc<llm_rs::kv_router::KvRouter>, PyErr> {
// Get component from endpoint
let component = endpoint.inner.component();
// Verify we're not in static mode
if component.drt().primary_lease().is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"Failed to get primary lease: Cannot KV route static workers",
));
}
// Create ModelManager and use it to create KvRouter (ensures etcd registration)
let model_manager = Arc::new(llm_rs::discovery::ModelManager::new());
let kv_router = model_manager
.kv_chooser_for(
"dummy_name", // does not matter, never cached
component,
block_size as u32,
kv_router_config,
)
.await
.map_err(to_pyerr)?;
Ok(kv_router)
}
#[pyclass] #[pyclass]
pub(crate) struct KvRouter { pub(crate) struct KvRouter {
inner: Arc<llm_rs::kv_router::KvRouter>, inner: Arc<llm_rs::kv_router::KvRouter>,
...@@ -842,40 +874,22 @@ pub(crate) struct KvRouter { ...@@ -842,40 +874,22 @@ pub(crate) struct KvRouter {
#[pymethods] #[pymethods]
impl KvRouter { impl KvRouter {
#[new] #[new]
#[pyo3(signature = (endpoint, block_size, kv_router_config=None, consumer_uuid=None))] #[pyo3(signature = (endpoint, block_size, kv_router_config=None))]
fn new( fn new(
endpoint: &Endpoint, endpoint: &Endpoint,
block_size: usize, block_size: usize,
kv_router_config: Option<&super::entrypoint::KvRouterConfig>, kv_router_config: Option<&super::entrypoint::KvRouterConfig>,
consumer_uuid: Option<String>,
) -> PyResult<Self> { ) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime(); let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async move { runtime.block_on(async move {
// Get component from endpoint let kv_router = create_kv_router_from_endpoint(
let component = endpoint.inner.component(); endpoint,
block_size,
// Verify we're not in static mode
if component.drt().primary_lease().is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"Failed to get primary lease: Cannot KV route static workers",
));
}
// Create KvRouter with provided or generated consumer UUID
let consumer_uuid = consumer_uuid.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let kv_router = llm_rs::kv_router::KvRouter::new(
component.clone(),
block_size as u32,
None, // default selector
kv_router_config.map(|c| c.inner()), kv_router_config.map(|c| c.inner()),
consumer_uuid,
) )
.await .await?;
.map_err(to_pyerr)?;
Ok(Self { Ok(Self { inner: kv_router })
inner: Arc::new(kv_router),
})
}) })
} }
...@@ -1040,27 +1054,13 @@ impl KvPushRouter { ...@@ -1040,27 +1054,13 @@ impl KvPushRouter {
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
// Get component from endpoint // Create KvRouter using helper function (ensures etcd registration)
let component = endpoint.inner.component(); let kv_router = create_kv_router_from_endpoint(
endpoint,
// Verify we're not in static mode block_size,
if component.drt().primary_lease().is_none() { Some(kv_router_config.inner()),
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>( )
"Failed to get primary lease: Cannot KV route static workers", .await?;
));
}
// Create ModelManager and use it to create KvRouter (ensures etcd registration)
let model_manager = Arc::new(llm_rs::discovery::ModelManager::new());
let kv_router = model_manager
.kv_chooser_for(
"dummy_name", // does not matter, never cached
component,
block_size as u32,
Some(kv_router_config.inner()),
)
.await
.map_err(to_pyerr)?;
// Create KvPushRouter (kv_router is already Arc<KvRouter>) // Create KvPushRouter (kv_router is already Arc<KvRouter>)
let kv_push_router = llm_rs::kv_router::KvPushRouter::new(push_router, kv_router); let kv_push_router = llm_rs::kv_router::KvPushRouter::new(push_router, kv_router);
......
...@@ -79,6 +79,7 @@ path = "hatch_build.py" ...@@ -79,6 +79,7 @@ path = "hatch_build.py"
packages = [ packages = [
"components/frontend/src/dynamo", "components/frontend/src/dynamo",
"components/planner/src/dynamo", "components/planner/src/dynamo",
"components/router/src/dynamo",
"components/backends/llama_cpp/src/dynamo", "components/backends/llama_cpp/src/dynamo",
"components/backends/mocker/src/dynamo", "components/backends/mocker/src/dynamo",
"components/backends/trtllm/src/dynamo", "components/backends/trtllm/src/dynamo",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment