Unverified Commit 0b33c1df authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

fix: sglang disagg routing fixes and optimizations [DYN-1692] (#5106)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarIshan Dhanani <ishandhanani@gmail.com>
Co-authored-by: default avatarSean SH Choi <sechoi@nvidia.com>
Co-authored-by: default avatarishandhanani <82981111+ishandhanani@users.noreply.github.com>
parent e49834c9
......@@ -84,7 +84,6 @@ def get_aiperf_cmd(
str(num_prefix_prompts),
"--artifact-dir",
artifact_dir,
"-v",
"-H",
"Authorization: Bearer NOT USED",
"-H",
......
......@@ -55,7 +55,6 @@ def get_aiperf_cmd_for_trace(
str(seed),
"--artifact-dir",
artifact_dir,
"-v",
"-H",
"Authorization: Bearer NOT USED",
"-H",
......
......@@ -161,8 +161,8 @@ def parse_args():
parser.add_argument(
"--router-max-tree-size",
type=int,
default=int(os.environ.get("DYN_ROUTER_MAX_TREE_SIZE", str(2**10))),
help="KV Router: Maximum tree size before pruning when KV events are disabled. Only used when --no-kv-events is set. Can be set via DYN_ROUTER_MAX_TREE_SIZE env var (default: 1024).",
default=int(os.environ.get("DYN_ROUTER_MAX_TREE_SIZE", str(2**20))),
help="KV Router: Maximum tree size before pruning when KV events are disabled. Only used when --no-kv-events is set. Can be set via DYN_ROUTER_MAX_TREE_SIZE env var (default: 1048576, which is 2^20).",
)
parser.add_argument(
"--router-prune-target-ratio",
......
......@@ -237,8 +237,8 @@ def parse_args():
parser.add_argument(
"--router-max-tree-size",
type=int,
default=2**10,
help="KV Router: Maximum tree size before pruning. Only used when --no-kv-events is set. When the indexer tree exceeds this size, pruning is triggered (default: 1024)",
default=2**20,
help="KV Router: Maximum tree size before pruning. Only used when --no-kv-events is set. When the indexer tree exceeds this size, pruning is triggered (default: 1048576, which is 2^20)",
)
parser.add_argument(
......
......@@ -137,17 +137,6 @@ async def init(runtime: DistributedRuntime, config: Config):
"Registered engine routes: /engine/start_profile, /engine/stop_profile"
)
# Create prefill client for disaggregated decode mode (fallback when --router-mode kv is not used)
prefill_client = None
if config.serving_mode == DisaggregationMode.DECODE:
logging.info("Initializing prefill client for disaggregated decode worker")
prefill_client = (
await runtime.namespace(dynamo_args.namespace)
.component("prefill")
.endpoint("generate")
.client()
)
# publisher instantiates the metrics and kv event publishers
publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint
......@@ -160,7 +149,7 @@ async def init(runtime: DistributedRuntime, config: Config):
# Readiness gate: requests wait until model is registered
ready_event = asyncio.Event()
handler = DecodeWorkerHandler(component, engine, config, publisher, prefill_client)
handler = DecodeWorkerHandler(component, engine, config, publisher)
print(f"Config: {config}")
health_check_payload = SglangHealthCheckPayload(
engine, use_text_input=dynamo_args.use_sglang_tokenizer
......
......@@ -15,7 +15,7 @@ import sglang as sgl
from sglang.srt.tracing import trace as sglang_trace
from sglang.srt.utils import get_local_ip_auto
from dynamo._core import Client, Component, Context
from dynamo._core import Component, Context
from dynamo.common.utils.input_params import InputParamManager
from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher
......@@ -30,7 +30,6 @@ class BaseWorkerHandler(ABC):
engine: sgl.Engine,
config: Config,
publisher: Optional[DynamoSglangPublisher] = None,
prefill_client: Optional[Client] = None,
) -> None:
"""Initialize base worker handler.
......@@ -39,7 +38,6 @@ class BaseWorkerHandler(ABC):
engine: The SGLang engine instance.
config: SGLang and Dynamo configuration.
publisher: Optional metrics publisher for the worker.
prefill_client: Optional client for prefill worker in disaggregated mode.
"""
self.component = component
self.engine = engine
......@@ -50,7 +48,6 @@ class BaseWorkerHandler(ABC):
else:
self.metrics_publisher = None
self.kv_publisher = None
self.prefill_client = prefill_client
self.serving_mode = config.serving_mode
self.skip_tokenizer_init = config.server_args.skip_tokenizer_init
self.enable_trace = config.server_args.enable_trace
......
......@@ -4,13 +4,12 @@
import asyncio
import logging
import time
from typing import Any, AsyncGenerator, Dict, Optional
from typing import Any, AsyncGenerator, Dict
import sglang as sgl
from dynamo._core import Client, Component, Context
from dynamo._core import Component, Context
from dynamo.sglang.args import Config, DisaggregationMode
from dynamo.sglang.protocol import DisaggPreprocessedRequest
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
......@@ -24,7 +23,6 @@ class DecodeWorkerHandler(BaseWorkerHandler):
engine: sgl.Engine,
config: Config,
publisher: DynamoSglangPublisher,
prefill_client: Optional[Client] = None,
) -> None:
"""Initialize decode worker handler.
......@@ -33,14 +31,12 @@ class DecodeWorkerHandler(BaseWorkerHandler):
engine: The SGLang engine instance.
config: SGLang and Dynamo configuration.
publisher: Metrics publisher for the worker.
prefill_client: Optional client for prefill worker in disaggregated mode.
"""
super().__init__(
component,
engine,
config,
publisher,
prefill_client,
)
if self.serving_mode == DisaggregationMode.DECODE:
logging.info(
......@@ -108,52 +104,16 @@ class DecodeWorkerHandler(BaseWorkerHandler):
input_param = self._get_input_param(request)
if self.serving_mode == DisaggregationMode.DECODE:
# Check if bootstrap_info is pre-computed in the request (from frontend with --router-mode kv)
# Check if bootstrap_info is pre-computed in the request (from frontend)
bootstrap_info = request.get("bootstrap_info")
if not bootstrap_info:
# Fallback: fetch bootstrap_info from prefill worker via round-robin routing
if self.prefill_client is None:
raise RuntimeError(
"bootstrap_info is required for disaggregated decode but was not provided, "
"and no prefill_client is available for fallback."
"bootstrap_info is required for disaggregated decode but was not provided"
)
logging.debug(
"No bootstrap_info in request, fetching from prefill worker"
)
prefill_stream = await self.prefill_client.generate(
DisaggPreprocessedRequest(
request=request,
sampling_params=sampling_params,
).model_dump(),
context=context,
)
prefill_response = None
async for info in prefill_stream:
prefill_response = info.data()
break
if not prefill_response:
raise RuntimeError("No response received from prefill worker")
# Extract bootstrap_info from disaggregated_params (PrefillWorkerHandler format)
bootstrap_info = prefill_response.get("disaggregated_params")
if not bootstrap_info:
raise RuntimeError(
"No bootstrap info (disaggregated_params) received from prefill worker"
)
logging.debug(
f"Received bootstrap_info from prefill worker: "
f"host={bootstrap_info['bootstrap_host']}, "
f"port={bootstrap_info['bootstrap_port']}, "
f"room={bootstrap_info['bootstrap_room']}"
)
else:
logging.debug(
f"Using pre-computed bootstrap_info: "
f"Using bootstrap_info: "
f"host={bootstrap_info['bootstrap_host']}, "
f"port={bootstrap_info['bootstrap_port']}, "
f"room={bootstrap_info['bootstrap_room']}"
......
......@@ -84,11 +84,12 @@ class PrefillWorkerHandler(BaseWorkerHandler):
k: v for k, v in sampling_params.items() if v is not None
}
# Use provided bootstrap_room if available, otherwise generate one
# Use provided bootstrap_room from bootstrap_info if available, otherwise generate one
bootstrap_room = None
extra_args = inner_request.get("extra_args", {})
if isinstance(extra_args, dict):
bootstrap_room = extra_args.get("bootstrap_room")
bootstrap_info_from_req = inner_request.get("bootstrap_info")
if isinstance(bootstrap_info_from_req, dict):
bootstrap_room = bootstrap_info_from_req.get("bootstrap_room")
if bootstrap_room is not None:
logging.debug(f"Using router-provided bootstrap_room: {bootstrap_room}")
if bootstrap_room is None:
......@@ -130,6 +131,8 @@ class PrefillWorkerHandler(BaseWorkerHandler):
self._consume_tasks.add(task)
task.add_done_callback(self._consume_tasks.discard)
await task
async def _consume_results(
self, results: AsyncGenerator[Any, None], context: Context
) -> None:
......
......@@ -52,7 +52,7 @@ The main KV-aware routing arguments:
> If you run with `DYN_REQUEST_PLANE=tcp` (or `http`) and KV events enabled (default), you must also configure NATS, e.g. `NATS_SERVER=nats://...`.
> Only `--no-kv-events` removes the NATS requirement.
>
> When `--kv-overlap-score-weight` is set to 0 or `--no-kv-events` is set, no KvIndexer will be launched to drain and process KV events. It's recommended to disable your backend workers from relaying events through `KvEventPublisher` to avoid event accumulation in JetStream. WIP to enable disabling publishing of KV events completely in these cases.
> When `--kv-overlap-score-weight` is set to 0, no KvIndexer is created and prefix matching is disabled (pure load balancing). When `--no-kv-events` is set, a KvIndexer is still created but no event subscriber is launched to consume KV events from workers. Instead, the router predicts cache state based on its own routing decisions with TTL-based expiration and pruning. In both cases, it's recommended to disable your backend workers from publishing events through `KvEventPublisher` to avoid event accumulation in JetStream. WIP to enable disabling publishing of KV events completely in these cases.
>
> The cli args `--router-ttl`, `--router-max-tree-size`, and `--router-prune-target-ratio` control local cache management when the router operates without receiving events from workers. When KV events are enabled (default), the router relies on worker-side eviction events and these parameters are ignored.
......
......@@ -5,8 +5,8 @@
# Setup cleanup trap
cleanup() {
echo "Cleaning up background processes..."
kill $DYNAMO_PID $PREFILL_PID1 $PREFILL_PID2 $DECODE_PID1 2>/dev/null || true
wait $DYNAMO_PID $PREFILL_PID1 $PREFILL_PID2 $DECODE_PID1 2>/dev/null || true
kill $DYNAMO_PID $PREFILL_PID1 $PREFILL_PID2 $DECODE_PID1 $DECODE_PID2 2>/dev/null || true
wait $DYNAMO_PID $PREFILL_PID1 $PREFILL_PID2 $DECODE_PID1 $DECODE_PID2 2>/dev/null || true
echo "Cleanup complete."
}
trap cleanup EXIT INT TERM
......@@ -26,7 +26,7 @@ while [[ $# -gt 0 ]]; do
echo " -h, --help Show this help message"
echo ""
echo "Note: System metrics are enabled by default on ports:"
echo " 8082-8083 (prefill workers), 8084-8085 (decode workers)"
echo " 8081-8082 (prefill workers), 8083-8084 (decode workers)"
exit 0
;;
*)
......@@ -48,6 +48,7 @@ fi
# Start frontend with KV routing
# The frontend will automatically detect prefill workers and activate an internal prefill router
# No standalone prefill router needed - the frontend handles prefill routing internally
# dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000)
OTEL_SERVICE_NAME=dynamo-frontend \
python3 -m dynamo.frontend \
......@@ -55,19 +56,8 @@ python3 -m dynamo.frontend \
--router-reset-states &
DYNAMO_PID=$!
# run prefill router
# Use numeric DYN_SYSTEM_PORT{N} env vars so launchers/test harnesses can set
# ports without encoding role names (prefill/decode) in the env var.
OTEL_SERVICE_NAME=dynamo-router-prefill DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT1:-8081} \
python3 -m dynamo.router \
--endpoint dynamo.prefill.generate \
--block-size 64 \
--router-reset-states \
--no-track-active-blocks &
PREFILL_ROUTER_PID=$!
# run prefill worker
OTEL_SERVICE_NAME=dynamo-worker-prefill-1 DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT2:-8082} \
OTEL_SERVICE_NAME=dynamo-worker-prefill-1 DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT1:-8081} \
python3 -m dynamo.sglang \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
......@@ -80,10 +70,10 @@ python3 -m dynamo.sglang \
--disaggregation-transfer-backend nixl \
--enable-metrics \
"${TRACE_ARGS[@]}" &
PREFILL_PID=$!
PREFILL_PID1=$!
# run prefill worker
OTEL_SERVICE_NAME=dynamo-worker-prefill-2 DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT3:-8083} \
OTEL_SERVICE_NAME=dynamo-worker-prefill-2 DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT2:-8082} \
CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
......@@ -96,10 +86,10 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \
--disaggregation-transfer-backend nixl \
--enable-metrics \
"${TRACE_ARGS[@]}" &
PREFILL_PID=$!
PREFILL_PID2=$!
# run decode worker
OTEL_SERVICE_NAME=dynamo-worker-decode-1 DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT4:-8084} \
OTEL_SERVICE_NAME=dynamo-worker-decode-1 DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT3:-8083} \
CUDA_VISIBLE_DEVICES=3 python3 -m dynamo.sglang \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
......@@ -112,10 +102,10 @@ CUDA_VISIBLE_DEVICES=3 python3 -m dynamo.sglang \
--disaggregation-transfer-backend nixl \
--enable-metrics \
"${TRACE_ARGS[@]}" &
PREFILL_PID=$!
DECODE_PID1=$!
# run decode worker
OTEL_SERVICE_NAME=dynamo-worker-decode-2 DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT5:-8085} \
OTEL_SERVICE_NAME=dynamo-worker-decode-2 DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT4:-8084} \
CUDA_VISIBLE_DEVICES=2 python3 -m dynamo.sglang \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
......@@ -127,4 +117,8 @@ CUDA_VISIBLE_DEVICES=2 python3 -m dynamo.sglang \
--kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5559"}' \
--disaggregation-transfer-backend nixl \
--enable-metrics \
"${TRACE_ARGS[@]}"
"${TRACE_ARGS[@]}" &
DECODE_PID2=$!
# Wait for any worker to exit (keeps script running)
wait
......@@ -11,7 +11,7 @@ use std::sync::atomic::{AtomicU32, Ordering};
use dynamo_llm::{
discovery::{KvWorkerMonitor, ModelWatcher},
kv_router::{indexer::compute_block_hash_for_seq, protocols::*, publisher::KvEventPublisher},
kv_router::{protocols::*, publisher::KvEventPublisher},
};
use dynamo_runtime::{DistributedRuntime, Worker};
static WK: OnceCell<Worker> = OnceCell::new();
......@@ -1475,6 +1475,7 @@ pub async fn create_worker_selection_pipeline_chat(
>(
&card_with_local_files,
&client,
model_manager.clone(),
router_mode,
worker_monitor,
chooser,
......
......@@ -1587,6 +1587,7 @@ dependencies = [
"tokio-util",
"tracing",
"url",
"utoipa",
"uuid",
]
......@@ -7792,6 +7793,8 @@ dependencies = [
"quote",
"regex",
"syn 2.0.110",
"url",
"uuid",
]
[[package]]
......
......@@ -50,7 +50,7 @@ impl KvRouterConfig {
#[pymethods]
impl KvRouterConfig {
#[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1024, router_prune_target_ratio=0.8))]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8))]
#[allow(clippy::too_many_arguments)]
fn new(
overlap_score_weight: f64,
......
......@@ -11,11 +11,11 @@ use tokio_stream::StreamExt;
use super::*;
use crate::Component;
use llm_rs::kv_router::indexer::KvIndexerInterface;
use llm_rs::kv_router::indexer::compute_block_hash_for_seq;
use llm_rs::kv_router::protocols::ForwardPassMetrics as RsForwardPassMetrics;
use llm_rs::kv_router::protocols::KvStats as RsKvStats;
use llm_rs::kv_router::protocols::SpecDecodeStats as RsSpecDecodeStats;
use llm_rs::kv_router::protocols::WorkerStats as RsWorkerStats;
use llm_rs::kv_router::protocols::compute_block_hash_for_seq;
use rs::pipeline::{AsyncEngine, SingleIn};
use rs::traits::events::EventSubscriber;
use tracing;
......@@ -782,7 +782,7 @@ pub(crate) struct ApproxKvIndexer {
#[pymethods]
impl ApproxKvIndexer {
#[new]
#[pyo3(signature = (component, kv_block_size, router_ttl_secs=120.0, router_max_tree_size=1024, router_prune_target_ratio=0.8))]
#[pyo3(signature = (component, kv_block_size, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8))]
fn new(
component: Component,
kv_block_size: usize,
......@@ -851,10 +851,12 @@ impl ApproxKvIndexer {
dp_rank: DpRank,
) -> PyResult<Bound<'p, PyAny>> {
let indexer = self.inner.clone();
let block_size = self.inner.block_size();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let worker = llm_rs::kv_router::protocols::WorkerWithDpRank::new(worker_id, dp_rank);
let mut tokens_with_hashes = TokensWithHashes::new(tokens, block_size);
indexer
.process_routing_decision_for_request(tokens.as_slice(), worker)
.process_routing_decision_for_request(&mut tokens_with_hashes, worker)
.await
.map_err(to_pyerr)?;
Ok(())
......
......@@ -684,7 +684,7 @@ class ApproxKvIndexer:
component: Component,
kv_block_size: int,
router_ttl_secs: float = 120.0,
router_max_tree_size: int = 1024,
router_max_tree_size: int = 1048576,
router_prune_target_ratio: float = 0.8,
) -> None:
"""
......@@ -694,7 +694,7 @@ class ApproxKvIndexer:
component: The component to associate with this indexer
kv_block_size: The KV cache block size
router_ttl_secs: TTL for blocks in seconds (default: 120.0)
router_max_tree_size: Maximum tree size before pruning (default: 1024)
router_max_tree_size: Maximum tree size before pruning (default: 1048576, which is 2^20)
router_prune_target_ratio: Target size ratio after pruning (default: 0.8)
"""
...
......@@ -1091,7 +1091,7 @@ class KvRouterConfig:
router_snapshot_threshold: Optional[int] = 1000000,
router_reset_states: bool = False,
router_ttl_secs: float = 120.0,
router_max_tree_size: int = 1024,
router_max_tree_size: int = 1048576,
router_prune_target_ratio: float = 0.8,
) -> None:
"""
......@@ -1106,7 +1106,7 @@ class KvRouterConfig:
router_snapshot_threshold: Number of messages before snapshot (default: 1000000)
router_reset_states: Reset router state on startup (default: False)
router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0)
router_max_tree_size: Maximum tree size before pruning (default: 1024)
router_max_tree_size: Maximum tree size before pruning (default: 1048576, which is 2^20)
router_prune_target_ratio: Target size ratio after pruning (default: 0.8)
"""
...
......
......@@ -6,6 +6,7 @@ use std::{
sync::Arc,
};
use dashmap::{DashMap, mapref::entry::Entry};
use parking_lot::{Mutex, RwLock};
use tokio::sync::oneshot;
......@@ -13,13 +14,17 @@ use crate::discovery::KvWorkerMonitor;
use dynamo_runtime::{
component::{Client, Endpoint, build_transport_type},
discovery::DiscoverySpec,
discovery::{DiscoveryQuery, DiscoverySpec, watch_and_extract_field},
prelude::DistributedRuntimeProvider,
protocols::EndpointId,
};
use crate::{
kv_router::{KvRouter, KvRouterConfig, router_endpoint_id, scheduler::DefaultWorkerSelector},
kv_router::{
KvRouter, KvRouterConfig, protocols::WorkerId, router_endpoint_id,
scheduler::DefaultWorkerSelector,
},
local_model::runtime_config::{DisaggregatedEndpoint, ModelRuntimeConfig},
model_card::ModelDeploymentCard,
model_type::ModelType,
types::{
......@@ -73,6 +78,11 @@ pub struct ModelManager {
/// Key: model name, Value: cloneable monitor (all fields are Arc).
/// HTTP endpoint can update thresholds via monitor.set_threshold().
worker_monitors: RwLock<HashMap<String, KvWorkerMonitor>>,
/// Runtime configs per endpoint using DashMap for lock-free access.
/// Outer DashMap: keyed by EndpointId
/// Inner Arc<DashMap>: keyed by WorkerId, shared with KvScheduler
runtime_configs: DashMap<EndpointId, Arc<DashMap<WorkerId, Option<ModelRuntimeConfig>>>>,
}
impl Default for ModelManager {
......@@ -93,6 +103,7 @@ impl ModelManager {
kv_choosers: Mutex::new(HashMap::new()),
prefill_router_activators: Mutex::new(HashMap::new()),
worker_monitors: RwLock::new(HashMap::new()),
runtime_configs: DashMap::new(),
}
}
......@@ -360,10 +371,14 @@ impl ModelManager {
// Use instance_id (hex) as the consumer ID for NATS consumer coordination
let consumer_id = instance_id.to_string();
// Get or create runtime config watcher for this endpoint
let workers_with_configs = self.get_or_create_runtime_config_watcher(endpoint).await?;
let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
let chooser = KvRouter::new(
endpoint.clone(),
client,
workers_with_configs,
kv_cache_block_size,
Some(selector),
kv_router_config,
......@@ -602,6 +617,146 @@ impl ModelManager {
self.worker_monitors.read().get(model).cloned()
}
/// Get or create a runtime config watcher for an endpoint.
/// Spawns a background task to watch DiscoveryQuery::EndpointModels.
/// Returns a shared Arc<DashMap> that KvScheduler can use directly.
pub async fn get_or_create_runtime_config_watcher(
&self,
endpoint: &Endpoint,
) -> anyhow::Result<Arc<DashMap<WorkerId, Option<ModelRuntimeConfig>>>> {
let endpoint_id = endpoint.id();
// Fast path: return existing if present
if let Some(existing) = self.runtime_configs.get(&endpoint_id) {
return Ok(existing.clone());
}
// Atomic get-or-insert to avoid TOCTOU race
let inner_map = Arc::new(DashMap::new());
let (map, is_new) = match self.runtime_configs.entry(endpoint_id) {
Entry::Occupied(e) => (e.get().clone(), false),
Entry::Vacant(e) => {
e.insert(inner_map.clone());
(inner_map, true)
}
};
// Only spawn watcher if we were the one who inserted
if is_new {
self.spawn_runtime_config_watcher(endpoint, map.clone())
.await?;
}
Ok(map)
}
/// Get disaggregated endpoint for a specific worker.
/// Used by PrefillRouter for bootstrap info - works for ANY routing mode.
pub fn get_disaggregated_endpoint(
&self,
endpoint_id: &EndpointId,
worker_id: WorkerId,
) -> Option<DisaggregatedEndpoint> {
let inner_map = self.runtime_configs.get(endpoint_id)?;
let config_ref = inner_map.get(&worker_id)?;
config_ref.as_ref()?.disaggregated_endpoint.clone()
}
/// Spawn background task to watch runtime configs via discovery.
async fn spawn_runtime_config_watcher(
&self,
endpoint: &Endpoint,
inner_map: Arc<DashMap<WorkerId, Option<ModelRuntimeConfig>>>,
) -> anyhow::Result<()> {
let component = endpoint.component();
let cancellation_token = component.drt().primary_token();
// Set up discovery watch for EndpointModels
let discovery = component.drt().discovery();
let endpoint_id = endpoint.id();
let discovery_key = DiscoveryQuery::EndpointModels {
namespace: endpoint_id.namespace.clone(),
component: endpoint_id.component.clone(),
endpoint: endpoint_id.name.clone(),
};
let discovery_stream = discovery
.list_and_watch(discovery_key.clone(), Some(cancellation_token.clone()))
.await?;
// Extract runtime_config from ModelDeploymentCard
let mut runtime_configs_rx =
watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
card.runtime_config
});
// Also watch instance IDs
let client = endpoint.client().await?;
let mut instance_ids_rx = client.instance_avail_watcher();
// Spawn background task to update inner_map
let cancel_token = cancellation_token.clone();
tokio::spawn(async move {
tracing::trace!("ModelManager runtime config watcher started");
loop {
// Wait for either instances or configs to change
tokio::select! {
_ = cancel_token.cancelled() => {
tracing::trace!("ModelManager runtime config watcher shutting down");
break;
}
result = instance_ids_rx.changed() => {
if result.is_err() {
tracing::warn!("instance IDs watch sender shutdown in ModelManager");
break;
}
}
result = runtime_configs_rx.changed() => {
if result.is_err() {
tracing::warn!("runtime configs watch sender shutdown in ModelManager");
break;
}
}
}
// Get the latest values from both channels
let new_instance_ids = instance_ids_rx.borrow_and_update().clone();
let new_configs = runtime_configs_rx.borrow_and_update().clone();
// Update the DashMap
// First, remove workers that no longer exist
let current_workers: HashSet<WorkerId> =
inner_map.iter().map(|r| *r.key()).collect();
let new_workers: HashSet<WorkerId> = new_instance_ids.iter().copied().collect();
for removed_worker in current_workers.difference(&new_workers) {
inner_map.remove(removed_worker);
}
// Then, add/update workers
for worker_id in &new_instance_ids {
let config = new_configs.get(worker_id).cloned();
if config.is_some() {
let prev_config = inner_map.get(worker_id);
if prev_config.as_ref().map(|r| r.value()) != Some(&config) {
tracing::info!(
"ModelManager: Runtime config found for worker_id: {}",
worker_id
);
}
}
inner_map.insert(*worker_id, config);
}
tracing::trace!(
"ModelManager: Updated runtime_configs with {} workers",
inner_map.len()
);
}
tracing::trace!("ModelManager runtime config watcher shutting down");
});
Ok(())
}
/// Lists all models that have worker monitors (and thus busy thresholds) configured.
///
/// Returns a vector of (model_name, active_decode_blocks_threshold, active_prefill_tokens_threshold) tuples.
......
......@@ -449,6 +449,7 @@ impl ModelWatcher {
>(
card,
&client,
self.manager.clone(),
self.router_config.router_mode,
worker_monitor.clone(),
kv_chooser.clone(),
......@@ -482,6 +483,7 @@ impl ModelWatcher {
>(
card,
&client,
self.manager.clone(),
self.router_config.router_mode,
worker_monitor,
kv_chooser,
......
......@@ -172,6 +172,7 @@ where
pub async fn build_routed_pipeline<Req, Resp>(
card: &ModelDeploymentCard,
client: &Client,
model_manager: Arc<crate::discovery::ModelManager>,
router_mode: RouterMode,
worker_monitor: Option<KvWorkerMonitor>,
chooser: Option<Arc<KvRouter>>,
......@@ -196,6 +197,7 @@ where
build_routed_pipeline_with_preprocessor(
card,
client,
model_manager,
router_mode,
worker_monitor,
chooser,
......@@ -212,6 +214,7 @@ where
pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>(
card: &ModelDeploymentCard,
client: &Client,
model_manager: Arc<crate::discovery::ModelManager>,
router_mode: RouterMode,
worker_monitor: Option<KvWorkerMonitor>,
chooser: Option<Arc<KvRouter>>,
......@@ -277,8 +280,8 @@ where
};
// Use the provided prefill chooser, or create a disabled one if not provided
let prefill_chooser =
prefill_chooser.unwrap_or_else(|| PrefillRouter::disabled(router_mode, enforce_disagg));
let prefill_chooser = prefill_chooser
.unwrap_or_else(|| PrefillRouter::disabled(model_manager, router_mode, enforce_disagg));
let prefill_op = prefill_chooser.into_operator();
// Link with prefill chooser including backward edge for response flow
......
......@@ -6,6 +6,7 @@ use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use dashmap::DashMap;
use derive_builder::Builder;
use dynamo_runtime::{
component::{Client, Endpoint},
......@@ -40,13 +41,11 @@ use worker_query::WorkerQueryClient;
use crate::{
kv_router::{
approx::PruneConfig,
indexer::{
KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
compute_block_hash_for_seq, compute_seq_hash_for_block,
},
indexer::{KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent},
protocols::{
LocalBlockHash, RouterRequest, RouterResponse, WorkerId, WorkerSelectionResult,
WorkerWithDpRank,
LocalBlockHash, RouterRequest, RouterResponse, TokensWithHashes, WorkerId,
WorkerSelectionResult, WorkerWithDpRank, compute_block_hash_for_seq,
compute_seq_hash_for_block,
},
scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
sequence::SequenceError,
......@@ -57,7 +56,6 @@ use crate::{
preprocessor::PreprocessedRequest,
protocols::common::llm_backend::LLMEngineOutput,
protocols::common::timing::RequestPhase,
tokens::SequenceHash,
};
// [gluo TODO] shouldn't need to be public
......@@ -148,7 +146,7 @@ pub struct KvRouterConfig {
/// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
pub router_ttl_secs: f64,
/// Maximum tree size before pruning (only used when use_kv_events is false, default: 1024)
/// Maximum tree size before pruning (only used when use_kv_events is false, default: 2^20 = 1048576)
pub router_max_tree_size: usize,
/// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
......@@ -166,7 +164,7 @@ impl Default for KvRouterConfig {
router_snapshot_threshold: Some(1000000),
router_reset_states: false,
router_ttl_secs: 120.0,
router_max_tree_size: 1024,
router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default()
router_prune_target_ratio: 0.8,
}
}
......@@ -244,16 +242,15 @@ impl Indexer {
}
}
async fn process_routing_decision(
async fn process_routing_decision_for_request(
&self,
tokens_with_hashes: &mut TokensWithHashes,
worker: WorkerWithDpRank,
local_hashes: Vec<LocalBlockHash>,
sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> {
match self {
Indexer::KvIndexer(indexer) => {
indexer
.process_routing_decision(worker, local_hashes, sequence_hashes)
.process_routing_decision_for_request(tokens_with_hashes, worker)
.await
}
Indexer::None => Ok(()),
......@@ -284,6 +281,7 @@ impl KvRouter {
pub async fn new(
endpoint: Endpoint,
client: Client,
workers_with_configs: Arc<DashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>>,
block_size: u32,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
kv_router_config: Option<KvRouterConfig>,
......@@ -296,6 +294,7 @@ impl KvRouter {
let instance_ids_rx = client.instance_avail_watcher();
// Watch for runtime config updates via discovery interface
// (still needed for WorkerQueryClient and background tasks)
let discovery = component.drt().discovery();
let endpoint_id = endpoint.id();
let discovery_key = DiscoveryQuery::EndpointModels {
......@@ -341,7 +340,7 @@ impl KvRouter {
component.clone(),
block_size,
instance_ids_rx,
runtime_configs_rx.clone(),
workers_with_configs,
selector,
kv_router_config.router_replica_sync,
consumer_id.clone(),
......@@ -476,41 +475,28 @@ impl KvRouter {
let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
let seq_hashes = compute_seq_hash_for_block(&block_hashes);
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
// Determine who needs seq_hashes
let needs_process_routing = !self.kv_router_config.use_kv_events;
let scheduler_needs_it = self.kv_router_config.router_track_active_blocks;
// Optimize cloning: only clone if both need it, otherwise move
let (maybe_seq_hashes_1, maybe_seq_hashes_2) =
match (needs_process_routing, scheduler_needs_it) {
(true, true) => (Some(seq_hashes.clone()), Some(seq_hashes)),
(true, false) => (Some(seq_hashes), None),
(false, true) => (None, Some(seq_hashes)),
(false, false) => (None, None),
};
// Compute seq_hashes only if scheduler needs it for active blocks tracking
let maybe_seq_hashes = self
.kv_router_config
.router_track_active_blocks
.then(|| compute_seq_hash_for_block(&block_hashes));
let best_worker = self
.scheduler
.schedule(
context_id.map(|s| s.to_string()),
isl_tokens,
maybe_seq_hashes_2,
maybe_seq_hashes,
overlap_scores.clone(),
router_config_override,
update_states,
)
.await?;
// Process routing decision when not using KV events (approximate mode with TTL/pruning)
if needs_process_routing {
self.indexer
.process_routing_decision(best_worker, block_hashes, maybe_seq_hashes_1.unwrap())
.await?;
}
// Note: Routing decision recording (for approximate mode) is now handled
// by KvPushRouter::generate after select_worker returns.
let overlap_amount = overlap_scores
.scores
......@@ -573,25 +559,16 @@ impl KvRouter {
Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0))
}
/// Get the disaggregated endpoint for a worker, if available.
/// Used to look up bootstrap host/port for prefill workers.
pub async fn get_disaggregated_endpoint(
&self,
worker_id: u64,
) -> Option<crate::local_model::runtime_config::DisaggregatedEndpoint> {
self.scheduler.get_disaggregated_endpoint(worker_id).await
}
/// Get potential prefill and decode loads for all workers
pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
let overlap_scores = self.indexer.find_matches(block_hashes).await?;
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
compute_seq_hash_for_block(&block_hashes)
});
let maybe_seq_hashes = self
.kv_router_config
.router_track_active_blocks
.then(|| compute_seq_hash_for_block(&block_hashes));
Ok(self
.scheduler
......@@ -838,14 +815,16 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let is_query_only = request.get_annotation_value("query_instance_id").is_some();
// Determine if this router should handle local state updates (add_request, free, etc.)
// When routing hints are present, the external caller handles state tracking
// via separate API calls, so we skip local updates here.
// Only skip local updates for GAIE Stage 2: when BOTH prefill and decode worker IDs
// are externally specified (indicates external orchestrator handles tracking).
// For internal routing (e.g., bootstrap optimization with only prefill_worker_id set),
// we still handle updates locally.
let routing = request.routing.as_ref();
let handle_local_updates = routing
.map(|r| {
// No routing hints = we handle updates locally
r.backend_instance_id.is_none()
&& (r.prefill_worker_id.is_none() || r.decode_worker_id.is_none())
// GAIE Stage 2 sets both worker IDs - external caller handles tracking
// All other cases (including backend_instance_id for routing) - we handle locally
r.prefill_worker_id.is_none() || r.decode_worker_id.is_none()
})
.unwrap_or(true);
......@@ -872,6 +851,29 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
overlap_amount,
} = selection;
// In approximate mode (use_kv_events=false), record the routing decision
// so the indexer can track cache state based on routing decisions.
// This covers both pre-selected workers and find_best_match selections.
if !is_query_only && !self.chooser.kv_router_config.use_kv_events {
let worker = WorkerWithDpRank::new(instance_id, dp_rank);
let mut tokens_with_hashes =
TokensWithHashes::new(request.token_ids.clone(), self.chooser.block_size);
if let Err(e) = self
.chooser
.indexer
.process_routing_decision_for_request(&mut tokens_with_hashes, worker)
.await
{
tracing::warn!(
request_id = %context_id,
worker_id = instance_id,
dp_rank = dp_rank,
error = %e,
"Failed to record routing decision in approximate mode"
);
}
}
// Record metrics in tracker: KV hit rate and worker ID based on phase
if let Some(ref tracker) = request.tracker {
let isl_blocks = request.token_ids.len().div_ceil(block_size);
......@@ -936,11 +938,18 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
};
if !prefill_marked {
// Only mark prefill completed when we receive actual tokens,
// not empty bootstrap info (token_ids: []) from disaggregated prefill
let has_tokens = item.data.as_ref()
.map(|d| !d.token_ids.is_empty())
.unwrap_or(false);
if has_tokens {
if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
tracing::warn!("Failed to mark prefill completed for request {context_id}: {e}");
}
prefill_marked = true;
}
}
yield item;
}
......
......@@ -223,7 +223,7 @@ impl<K: Clone + Hash + Eq + Ord> PruneManager<K> {
mod tests {
use super::*;
use crate::kv_router::indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics};
use crate::kv_router::protocols::{WorkerId, WorkerWithDpRank};
use crate::kv_router::protocols::{TokensWithHashes, WorkerId, WorkerWithDpRank};
use std::sync::Arc;
use tokio::time::{self, Duration, Instant};
use tokio_util::sync::CancellationToken;
......@@ -355,9 +355,10 @@ mod tests {
assert!(pre_scores.scores.is_empty());
// 2. Inform indexer about routing decision
let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
indexer
.process_routing_decision_for_request(
&tokens,
&mut tokens_with_hashes,
WorkerWithDpRank::from_worker_id(worker_id),
)
.await
......@@ -401,9 +402,10 @@ mod tests {
let tokens: Vec<u32> = vec![10, 11, 12, 13];
let worker_id: WorkerId = 7;
let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
indexer
.process_routing_decision_for_request(
&tokens,
&mut tokens_with_hashes,
WorkerWithDpRank::from_worker_id(worker_id),
)
.await
......@@ -454,16 +456,18 @@ mod tests {
let worker_1: WorkerId = 31;
// Register on both workers
let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
indexer
.process_routing_decision_for_request(
&tokens,
&mut tokens_with_hashes,
WorkerWithDpRank::from_worker_id(worker_0),
)
.await
.unwrap();
let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
indexer
.process_routing_decision_for_request(
&tokens,
&mut tokens_with_hashes,
WorkerWithDpRank::from_worker_id(worker_1),
)
.await
......@@ -524,9 +528,10 @@ mod tests {
let worker_a: WorkerId = 11;
// Register Sequence A on worker A
let mut tokens_with_hashes = TokensWithHashes::new(seq_a.clone(), KV_BLOCK_SIZE);
indexer
.process_routing_decision_for_request(
&seq_a,
&mut tokens_with_hashes,
WorkerWithDpRank::from_worker_id(worker_a),
)
.await
......@@ -582,16 +587,18 @@ mod tests {
let worker_1: WorkerId = 22;
// Register the same sequence on two different workers
let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
indexer
.process_routing_decision_for_request(
&tokens,
&mut tokens_with_hashes,
WorkerWithDpRank::from_worker_id(worker_0),
)
.await
.unwrap();
let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
indexer
.process_routing_decision_for_request(
&tokens,
&mut tokens_with_hashes,
WorkerWithDpRank::from_worker_id(worker_1),
)
.await
......@@ -759,8 +766,9 @@ mod tests {
// Insert 5 sequences (5 blocks total, at max_tree_size but not exceeding)
for i in 0..5 {
let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
let mut tokens_with_hashes = TokensWithHashes::new(tokens, KV_BLOCK_SIZE);
indexer
.process_routing_decision_for_request(&tokens, worker)
.process_routing_decision_for_request(&mut tokens_with_hashes, worker)
.await
.unwrap();
time::sleep(Duration::from_millis(1)).await; // Ensure different timestamps
......@@ -780,8 +788,9 @@ mod tests {
// Insert 6th block - this exceeds max_tree_size and should trigger reactive pruning
let tokens: Vec<u32> = vec![50, 51, 52, 53];
let mut tokens_with_hashes = TokensWithHashes::new(tokens, KV_BLOCK_SIZE);
indexer
.process_routing_decision_for_request(&tokens, worker)
.process_routing_decision_for_request(&mut tokens_with_hashes, worker)
.await
.unwrap();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment