Unverified Commit a2154ba5 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore(mocker): batch live output signal sends (#7647)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 06f17011
......@@ -35,7 +35,6 @@ _KV_ROUTER_FIELDS: tuple[str, ...] = (
"router_queue_threshold",
"router_event_threads",
"router_enable_cache_control",
"min_initial_workers",
"router_queue_policy",
"remote_indexer_component",
)
......@@ -61,7 +60,6 @@ class KvRouterConfigBase(ConfigBase):
router_queue_threshold: Optional[float]
router_event_threads: int
router_enable_cache_control: bool
min_initial_workers: int
router_queue_policy: str
remote_indexer_component: Optional[str]
......@@ -274,18 +272,6 @@ class KvRouterArgGroup(ArgGroup):
"requests with nvext.cache_control."
),
)
add_argument(
g,
flag_name="--router-min-initial-workers",
env_var="DYN_ROUTER_MIN_INITIAL_WORKERS",
default=1,
help=(
"KV Router: Minimum number of workers that must be discovered before "
"router startup continues. Ignored when skip_initial_worker_wait is enabled."
),
arg_type=int,
dest="min_initial_workers",
)
add_argument(
g,
flag_name="--router-queue-policy",
......
......@@ -50,6 +50,7 @@ class FrontendConfig(KvRouterConfigBase):
tls_key_path: Optional[pathlib.Path]
router_mode: str
min_initial_workers: int
namespace: Optional[str] = None
namespace_prefix: Optional[str] = None
enforce_disagg: bool
......@@ -90,6 +91,8 @@ class FrontendConfig(KvRouterConfigBase):
raise ValueError(
"--migration-limit must be between 0 and 4294967295 (0=disabled)"
)
if self.min_initial_workers < 0:
raise ValueError("--router-min-initial-workers must be >= 0")
if self.router_enable_cache_control and self.router_mode != "kv":
raise ValueError("--enable-cache-control requires --router-mode=kv")
if self.tokenizer_backend not in self._VALID_TOKENIZER_BACKENDS:
......@@ -188,6 +191,20 @@ class FrontendArgGroup(ArgGroup):
"synchronous prefill path.",
choices=["round-robin", "random", "power-of-two", "kv", "direct"],
)
add_argument(
g,
flag_name="--router-min-initial-workers",
env_var="DYN_ROUTER_MIN_INITIAL_WORKERS",
default=0,
help=(
"Minimum number of workers required before router startup continues. "
"This is exported as DYN_ROUTER_MIN_INITIAL_WORKERS so the generic "
"push-router path and the KV router's config-ready worker gate share "
"the same startup threshold. Set to 0 to disable the startup wait."
),
arg_type=int,
dest="min_initial_workers",
)
# KV router options (shared with dynamo.router)
KvRouterArgGroup().add_arguments(parser)
......
......@@ -47,6 +47,8 @@ if TYPE_CHECKING:
configure_dynamo_logging()
logger = logging.getLogger(__name__)
MIN_INITIAL_WORKERS_ENV = "DYN_ROUTER_MIN_INITIAL_WORKERS"
def setup_engine_factory(
runtime: DistributedRuntime,
......@@ -232,6 +234,7 @@ async def async_main():
router_mode = RouterMode.RoundRobin
kv_router_config = None
os.environ[MIN_INITIAL_WORKERS_ENV] = str(config.min_initial_workers)
router_config = RouterConfig(
router_mode,
kv_router_config,
......
......@@ -60,6 +60,9 @@ async def main():
kv_router_config=kv_router_config
)
# Optional startup gate shared with the frontend and standalone indexer:
# os.environ["DYN_ROUTER_MIN_INITIAL_WORKERS"] = "2"
# Your input tokens
token_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
......
......@@ -16,6 +16,10 @@ This is distinct from the [Standalone Router](../../../components/src/dynamo/rou
The HTTP API follows the [Mooncake KV Indexer RFC](https://github.com/kvcache-ai/Mooncake/issues/1403) conventions.
`DYN_ROUTER_MIN_INITIAL_WORKERS` is also honored here. When set to a positive integer, the
standalone indexer waits for that many workers to register before opening its startup-ready
gate, matching the frontend/router startup behavior.
## Multi-Model and Multi-Tenant Support
The indexer maintains one radix tree per `(model_name, tenant_id)` pair. Workers registered with different model names or tenant IDs are isolated into separate indexers — queries against one model/tenant never return scores from another.
......@@ -143,6 +147,12 @@ In runtime mode, workers are discovered automatically via MDC. The `--workers` f
| `--component-name` | `kv-indexer` | Component name for this indexer in the Dynamo runtime |
| `--worker-component` | `backend` | Component name that workers register under for event-plane subscription |
### Shared Startup Gate
Set `DYN_ROUTER_MIN_INITIAL_WORKERS=<n>` to require at least `<n>` workers before the
standalone indexer, frontend push-router path, and KV router config-ready gate all proceed.
Leave it unset or set it to `0` to disable the startup wait.
## HTTP API
### `GET /health` — Liveness check
......
......@@ -378,24 +378,22 @@ async fn apply_entry(
isl,
OverlapScores::default(),
);
let _ = multi
.add_request(SequenceRequest {
request_id,
token_sequence: Some(block_hashes),
isl,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: Some(output_length as u32),
worker,
lora_name: None,
})
.await;
let _ = multi.add_request(SequenceRequest {
request_id,
token_sequence: Some(block_hashes),
isl,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: Some(output_length as u32),
worker,
lora_name: None,
});
}
SequenceTraceEntry::PrefillComplete { request_id } => {
let _ = multi.mark_prefill_completed(&request_id).await;
let _ = multi.mark_prefill_completed(&request_id);
}
SequenceTraceEntry::Free { request_id } => {
let _ = multi.free(&request_id).await;
let _ = multi.free(&request_id);
}
}
}
......
......@@ -389,7 +389,7 @@ async fn replay_worker_trace(
.rescale_ready_span(trace_simulation_duration_ms)?
.into_trace_driver()?;
let collector = EventCollector::new();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let scheduler = Scheduler::new(
sched_args,
0,
......@@ -429,32 +429,40 @@ async fn replay_worker_trace(
let deadline = start + Duration::from_secs_f64((next_ready_ms.max(0.0)) / 1000.0);
tokio::select! {
maybe_signal = output_rx.recv() => {
let Some(signal) = maybe_signal else {
let Some(output_batch) = maybe_signal else {
anyhow::bail!("scheduler ended before workload replay drained");
};
output_signals.push(TimedOutputSignal {
signal: signal.clone(),
timestamp_us: start.elapsed().as_micros() as u64,
});
if signal.completed {
completed_turns += 1;
driver.on_complete(signal.uuid, start.elapsed().as_secs_f64() * 1000.0)?;
let timestamp_us = start.elapsed().as_micros() as u64;
let completion_ms = start.elapsed().as_secs_f64() * 1000.0;
for signal in output_batch {
output_signals.push(TimedOutputSignal {
signal: signal.clone(),
timestamp_us,
});
if signal.completed {
completed_turns += 1;
driver.on_complete(signal.uuid, completion_ms)?;
}
}
}
_ = tokio::time::sleep_until(deadline) => {}
}
}
None => {
let Some(signal) = output_rx.recv().await else {
let Some(output_batch) = output_rx.recv().await else {
anyhow::bail!("scheduler ended before workload replay drained");
};
output_signals.push(TimedOutputSignal {
signal: signal.clone(),
timestamp_us: start.elapsed().as_micros() as u64,
});
if signal.completed {
completed_turns += 1;
driver.on_complete(signal.uuid, start.elapsed().as_secs_f64() * 1000.0)?;
let timestamp_us = start.elapsed().as_micros() as u64;
let completion_ms = start.elapsed().as_secs_f64() * 1000.0;
for signal in output_batch {
output_signals.push(TimedOutputSignal {
signal: signal.clone(),
timestamp_us,
});
if signal.completed {
completed_turns += 1;
driver.on_complete(signal.uuid, completion_ms)?;
}
}
}
}
......
......@@ -58,7 +58,7 @@ impl KvRouterConfig {
#[pymethods]
impl KvRouterConfig {
#[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_track_prefill_tokens=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(4.0), router_event_threads=4, router_enable_cache_control=false, min_initial_workers=1, router_queue_policy="fcfs", remote_indexer_component=None))]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_track_prefill_tokens=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(4.0), router_event_threads=4, router_enable_cache_control=false, router_queue_policy="fcfs", remote_indexer_component=None))]
#[allow(clippy::too_many_arguments)]
fn new(
overlap_score_weight: f64,
......@@ -78,7 +78,6 @@ impl KvRouterConfig {
router_queue_threshold: Option<f64>,
router_event_threads: u32,
router_enable_cache_control: bool,
min_initial_workers: usize,
router_queue_policy: &str,
remote_indexer_component: Option<String>,
) -> Self {
......@@ -102,7 +101,6 @@ impl KvRouterConfig {
router_event_threads,
router_enable_cache_control,
skip_initial_worker_wait: false,
min_initial_workers,
router_queue_policy: router_queue_policy.parse().unwrap_or_else(|_| {
panic!("invalid router_queue_policy: {router_queue_policy:?}")
}),
......
......@@ -1179,7 +1179,6 @@ class KvRouterConfig:
router_queue_threshold: Optional[float] = 4.0,
router_event_threads: int = 4,
router_enable_cache_control: bool = False,
min_initial_workers: int = 1,
router_queue_policy: str = "fcfs",
) -> None:
"""
......@@ -1213,9 +1212,6 @@ class KvRouterConfig:
When > 1, uses a concurrent radix tree with a thread pool.
router_enable_cache_control: Enable cache control (PIN with TTL) via the worker's
cache_control service mesh endpoint (default: False).
min_initial_workers: Minimum number of discovered workers required before
router startup continues (default: 1). Ignored when
skip_initial_worker_wait is enabled.
router_queue_policy: Scheduling policy for the router queue (default: "fcfs").
"fcfs": first-come first-served with priority bumps — optimizes tail TTFT.
"lcfs": last-come first-served with priority bumps — intentionally worsens tail behavior for policy comparisons.
......
......@@ -84,7 +84,6 @@ def _router_config_payload():
"router_prune_target_ratio": 0.8,
"router_enable_cache_control": False,
"skip_initial_worker_wait": False,
"min_initial_workers": 1,
"remote_indexer_component": None,
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::env::{self, VarError};
use std::fmt;
use std::str::FromStr;
......@@ -13,14 +14,26 @@ use crate::protocols::{
BlockHashOptions, LocalBlockHash, compute_block_hash_for_seq, compute_seq_hash_for_block,
};
const fn default_min_initial_workers() -> usize {
1
}
/// Returns `true`: prefill-token tracking is on by default.
///
/// NOTE(review): presumably wired in as a serde `default = "..."` provider for
/// `router_track_prefill_tokens`; the attribute is not visible here — confirm.
const fn default_track_prefill_tokens() -> bool {
true
}
pub const DYN_ROUTER_MIN_INITIAL_WORKERS: &str = "DYN_ROUTER_MIN_INITIAL_WORKERS";
pub fn min_initial_workers_from_env() -> anyhow::Result<usize> {
match env::var(DYN_ROUTER_MIN_INITIAL_WORKERS) {
Ok(value) => value.parse::<usize>().map_err(|error| {
anyhow::anyhow!(
"{DYN_ROUTER_MIN_INITIAL_WORKERS} must be a non-negative integer, got {value:?}: {error}"
)
}),
Err(VarError::NotPresent) => Ok(0),
Err(VarError::NotUnicode(_)) => {
anyhow::bail!("{DYN_ROUTER_MIN_INITIAL_WORKERS} must be valid unicode")
}
}
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RouterQueuePolicy {
......@@ -149,17 +162,8 @@ pub struct KvRouterConfig {
/// When false (default), cache_control is ignored and no cache_control client is created.
pub router_enable_cache_control: bool,
/// Skip blocking for workers at init time (default: false).
/// When true, the router starts immediately without waiting for discovery-based
/// workers and workers are provided externally per-request (e.g., EPP).
pub skip_initial_worker_wait: bool,
/// Minimum number of workers that must be discovered before router startup continues.
/// Default: 1. Ignored when skip_initial_worker_wait=true.
#[serde(default = "default_min_initial_workers")]
#[validate(range(min = 1))]
pub min_initial_workers: usize,
/// Scheduling policy for the router queue.
/// "fcfs" (default): first-come first-served with priority bumps — optimizes tail TTFT.
/// "wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT.
......@@ -194,7 +198,6 @@ impl Default for KvRouterConfig {
router_event_threads: 4,
router_enable_cache_control: false,
skip_initial_worker_wait: false,
min_initial_workers: default_min_initial_workers(),
router_queue_policy: RouterQueuePolicy::default(),
remote_indexer_component: None,
}
......@@ -309,25 +312,11 @@ mod tests {
assert_eq!(deserialized, RouterQueuePolicy::Lcfs);
}
#[test]
fn kv_router_config_defaults_to_one_initial_worker() {
assert_eq!(KvRouterConfig::default().min_initial_workers, 1);
}
// Guards the default: `KvRouterConfig::default()` must leave
// `router_track_prefill_tokens` enabled.
#[test]
fn kv_router_config_defaults_to_tracking_prefill_tokens() {
assert!(KvRouterConfig::default().router_track_prefill_tokens);
}
#[test]
fn kv_router_config_rejects_zero_initial_workers() {
let cfg = KvRouterConfig {
min_initial_workers: 0,
..KvRouterConfig::default()
};
assert!(cfg.validate().is_err());
}
#[test]
fn compute_seq_hashes_for_tracking_uses_mm_hashes() {
let cfg = KvRouterConfig::default();
......
......@@ -192,19 +192,17 @@ where
}
pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
self.slots.add_request(req).await
self.slots.add_request(req)
}
pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
self.slots
.mark_prefill_completed(&request_id.to_string())
.await?;
self.slots.mark_prefill_completed(&request_id.to_string())?;
self.queue.update().await;
Ok(())
}
pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
self.slots.free(&request_id.to_string()).await?;
self.slots.free(&request_id.to_string())?;
self.queue.update().await;
Ok(())
}
......
......@@ -231,20 +231,16 @@ impl<
return;
};
if let Err(e) = self
.slots
.add_request(SequenceRequest {
request_id: request_id.clone(),
token_sequence: request.token_seq,
isl: request.isl_tokens,
overlap: selection.overlap_blocks,
track_prefill_tokens: request.track_prefill_tokens,
expected_output_tokens: request.expected_output_tokens,
worker: selection.worker,
lora_name: request.lora_name.clone(),
})
.await
{
if let Err(e) = self.slots.add_request(SequenceRequest {
request_id: request_id.clone(),
token_sequence: request.token_seq,
isl: request.isl_tokens,
overlap: selection.overlap_blocks,
track_prefill_tokens: request.track_prefill_tokens,
expected_output_tokens: request.expected_output_tokens,
worker: selection.worker,
lora_name: request.lora_name.clone(),
}) {
tracing::warn!("Failed to add request {request_id}: {e}");
}
}
......@@ -413,8 +409,8 @@ mod tests {
let resp = resp.expect("scheduling failed");
assert!(resp.best_worker.worker_id < num_workers as u64);
slots.mark_prefill_completed(&req_id).await.unwrap();
slots.free(&req_id).await.unwrap();
slots.mark_prefill_completed(&req_id).unwrap();
slots.free(&req_id).unwrap();
queue.update().await;
}));
}
......@@ -457,8 +453,8 @@ mod tests {
for _ in 0..num_requests {
queue.update().await;
for rid in &req_ids {
let _ = slots.mark_prefill_completed(rid).await;
let _ = slots.free(rid).await;
let _ = slots.mark_prefill_completed(rid);
let _ = slots.free(rid);
}
}
queue.update().await;
......@@ -499,11 +495,8 @@ mod tests {
assert_eq!(queue.pending_count(), 2);
// Free the first request and update — should drain one from pending
slots
.mark_prefill_completed(&"req-1".to_string())
.await
.unwrap();
slots.free(&"req-1".to_string()).await.unwrap();
slots.mark_prefill_completed(&"req-1".to_string()).unwrap();
slots.free(&"req-1".to_string()).unwrap();
queue.update().await;
// After update, one pending request should have been scheduled
......@@ -514,11 +507,11 @@ mod tests {
);
// Free req-2 and update to drain remaining
let _ = slots.mark_prefill_completed(&"req-2".to_string()).await;
let _ = slots.free(&"req-2".to_string()).await;
let _ = slots.mark_prefill_completed(&"req-2".to_string());
let _ = slots.free(&"req-2".to_string());
queue.update().await;
let _ = slots.mark_prefill_completed(&"req-3".to_string()).await;
let _ = slots.free(&"req-3".to_string()).await;
let _ = slots.mark_prefill_completed(&"req-3".to_string());
let _ = slots.free(&"req-3".to_string());
queue.update().await;
assert_eq!(queue.pending_count(), 0, "all requests should be drained");
......@@ -598,9 +591,8 @@ mod tests {
// Clean up
slots
.mark_prefill_completed(&"after-register".to_string())
.await
.unwrap();
slots.free(&"after-register".to_string()).await.unwrap();
slots.free(&"after-register".to_string()).unwrap();
}
/// Register_workers is additive: calling with a new set does NOT remove old workers.
......@@ -651,8 +643,8 @@ mod tests {
.expect("oneshot dropped")
.expect("scheduling failed");
seen.insert(resp.best_worker.worker_id);
slots.mark_prefill_completed(&req_id).await.unwrap();
slots.free(&req_id).await.unwrap();
slots.mark_prefill_completed(&req_id).unwrap();
slots.free(&req_id).unwrap();
}
assert!(
......@@ -721,9 +713,8 @@ mod tests {
);
slots
.mark_prefill_completed(&"filter-0".to_string())
.await
.unwrap();
slots.free(&"filter-0".to_string()).await.unwrap();
slots.free(&"filter-0".to_string()).unwrap();
}
#[tokio::test(flavor = "multi_thread")]
......@@ -747,9 +738,9 @@ mod tests {
let _resp2 = rx2.await.unwrap().unwrap();
assert_eq!(queue.pending_count(), 0);
let _ = slots.mark_prefill_completed(&"req-1".to_string()).await;
let _ = slots.free(&"req-1".to_string()).await;
let _ = slots.mark_prefill_completed(&"req-2".to_string()).await;
let _ = slots.free(&"req-2".to_string()).await;
let _ = slots.mark_prefill_completed(&"req-1".to_string());
let _ = slots.free(&"req-1".to_string());
let _ = slots.mark_prefill_completed(&"req-2".to_string());
let _ = slots.free(&"req-2".to_string());
}
}
......@@ -83,12 +83,6 @@ pub enum SequenceError {
#[error("Request {request_id} not found")]
RequestNotFound { request_id: String },
#[error("Failed to publish event: {0}")]
PublishFailed(#[from] anyhow::Error),
#[error("Synchronous mutation requires replica_sync=false")]
SyncMutationRequiresNoReplicaSync,
}
/// Bundled parameters for adding a request to the sequence tracker.
......@@ -147,7 +141,7 @@ pub struct ActiveSequencesMultiWorker<P: SequencePublisher> {
request_to_lora: DashMap<RequestId, String>,
block_size: usize,
router_id: u64,
publisher: P,
publisher: Arc<P>,
replica_sync: bool,
worker_type: &'static str,
}
......@@ -172,12 +166,29 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
request_to_lora: DashMap::new(),
block_size,
router_id,
publisher,
publisher: Arc::new(publisher),
replica_sync,
worker_type,
}
}
/// Hands `event` to the publisher on a detached background task.
///
/// Does nothing when `replica_sync` is disabled. Otherwise the event is moved
/// into a task started with `tokio::spawn`, so this method returns without
/// waiting for the publish to finish. A publish failure is logged through
/// `tracing::error!` and is never reported back to the caller.
///
/// NOTE(review): `tokio::spawn` panics when called outside a Tokio runtime
/// context, and independently spawned tasks carry no ordering guarantee
/// relative to one another — confirm that callers and subscribers tolerate both.
fn spawn_publish_event(&self, event: ActiveSequenceEvent) {
    if self.replica_sync {
        // The task must own its handle to the publisher, hence the extra `Arc`.
        let publisher = Arc::clone(&self.publisher);
        tokio::spawn(async move {
            let outcome = publisher.publish_event(&event).await;
            if let Err(e) = outcome {
                tracing::error!(
                    request_id = %event.request_id,
                    worker = ?event.worker,
                    "failed to publish active sequence event: {e}"
                );
            }
        });
    }
}
/// Spawn a background task that subscribes to replica-sync events from peer routers
/// and applies them to the local state.
pub fn start_replica_sync<S: SequenceSubscriber + 'static>(
......@@ -370,13 +381,6 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
}
}
fn ensure_sync_mutation_allowed(&self) -> Result<(), SequenceError> {
if self.replica_sync {
return Err(SequenceError::SyncMutationRequiresNoReplicaSync);
}
Ok(())
}
fn add_request_local(&self, req: SequenceRequest) -> Result<(), SequenceError> {
let SequenceRequest {
request_id,
......@@ -433,29 +437,20 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
Ok(())
}
pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
if self.replica_sync {
let event = ActiveSequenceEvent {
request_id: req.request_id.clone(),
worker: req.worker,
data: ActiveSequenceEventData::AddRequest {
token_sequence: req.token_sequence.clone(),
isl: req.isl,
overlap: req.overlap,
track_prefill_tokens: req.track_prefill_tokens,
expected_output_tokens: req.expected_output_tokens,
},
router_id: self.router_id,
lora_name: req.lora_name.clone(),
};
self.publisher.publish_event(&event).await?;
}
self.add_request_local(req)
}
pub fn add_request_sync(&self, req: SequenceRequest) -> Result<(), SequenceError> {
self.ensure_sync_mutation_allowed()?;
/// Registers a new request with the local sequence tracker, announcing it to
/// replicas first when replica sync is enabled.
///
/// The replica-sync event is built only when `replica_sync` is on. Building it
/// clones the request id, the optional token sequence and the LoRA name, which
/// is wasted work when `spawn_publish_event` would discard the event immediately.
///
/// Publishing is fire-and-forget (`spawn_publish_event` only logs failures), so
/// the returned `Result` reflects the outcome of `add_request_local` alone.
///
/// # Errors
/// Propagates whatever `add_request_local` returns for `req`.
pub fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
    if self.replica_sync {
        self.spawn_publish_event(ActiveSequenceEvent {
            request_id: req.request_id.clone(),
            worker: req.worker,
            data: ActiveSequenceEventData::AddRequest {
                token_sequence: req.token_sequence.clone(),
                isl: req.isl,
                overlap: req.overlap,
                track_prefill_tokens: req.track_prefill_tokens,
                expected_output_tokens: req.expected_output_tokens,
            },
            router_id: self.router_id,
            lora_name: req.lora_name.clone(),
        });
    }

    self.add_request_local(req)
}
......@@ -495,7 +490,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
Ok(())
}
async fn mutate_request_worker(
fn mutate_request_worker(
&self,
request_id: &RequestId,
event_data: ActiveSequenceEventData,
......@@ -510,21 +505,17 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
request_id: request_id.clone(),
})?;
if self.replica_sync {
let lora_name = self
.request_to_lora
.get(request_id)
.map(|entry| entry.value().clone());
let event = ActiveSequenceEvent {
request_id: request_id.clone(),
worker,
data: event_data,
router_id: self.router_id,
lora_name,
};
self.publisher.publish_event(&event).await?;
}
let lora_name = self
.request_to_lora
.get(request_id)
.map(|entry| entry.value().clone());
self.spawn_publish_event(ActiveSequenceEvent {
request_id: request_id.clone(),
worker,
data: event_data,
router_id: self.router_id,
lora_name,
});
self.mutate_request_worker_local(request_id, mutate_fn, remove_mapping)
}
......@@ -537,7 +528,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
/// This also performs the underlying prefill-complete cleanup via
/// [`ActiveSequences::free`], so callers do not need to call
/// [`Self::mark_prefill_completed`] before freeing a completed request.
pub async fn free(&self, request_id: &RequestId) -> Result<(), SequenceError> {
pub fn free(&self, request_id: &RequestId) -> Result<(), SequenceError> {
if !self.request_to_worker.contains_key(request_id) {
tracing::debug!("Request {request_id} not found, already freed (idempotent)");
return Ok(());
......@@ -551,32 +542,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
},
true,
)
.await
}
pub fn free_sync(&self, request_id: &RequestId) -> Result<(), SequenceError> {
self.ensure_sync_mutation_allowed()?;
if !self.request_to_worker.contains_key(request_id) {
tracing::debug!("Request {request_id} not found, already freed (idempotent)");
return Ok(());
}
self.mutate_request_worker_local(
request_id,
|seqs, rid| {
seqs.free(rid);
},
true,
)
}
/// Mark prefill as completed for a request.
///
/// Note: Calling this multiple times for the same request is allowed and will be a no-op
/// after the first call (idempotent).
pub async fn mark_prefill_completed(
&self,
request_id: &RequestId,
) -> Result<(), SequenceError> {
pub fn mark_prefill_completed(&self, request_id: &RequestId) -> Result<(), SequenceError> {
self.mutate_request_worker(
request_id,
ActiveSequenceEventData::MarkPrefillCompleted,
......@@ -585,18 +557,6 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
},
false,
)
.await
}
pub fn mark_prefill_completed_sync(&self, request_id: &RequestId) -> Result<(), SequenceError> {
self.ensure_sync_mutation_allowed()?;
self.mutate_request_worker_local(
request_id,
|seqs, rid| {
seqs.mark_prefill_completed(rid);
},
false,
)
}
/// Add an output block with optional fractional decay weight.
......@@ -895,7 +855,6 @@ mod tests {
worker,
lora_name: None,
})
.await
.unwrap();
assert_eq!(sequences.active_tokens().get(&worker).copied(), Some(0));
......
......@@ -11,10 +11,12 @@ pub mod runtime;
pub mod server;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use crate::config::min_initial_workers_from_env;
use registry::WorkerRegistry;
use server::{AppState, create_router};
......@@ -198,6 +200,33 @@ pub async fn run_with_runtime(
run_common(&config, &registry, cancel_token).await
}
async fn wait_for_min_initial_workers(
registry: &WorkerRegistry,
cancel_token: &CancellationToken,
) -> anyhow::Result<()> {
let min_initial_workers = min_initial_workers_from_env()?;
if min_initial_workers == 0 {
return Ok(());
}
loop {
let registered_workers = registry.list().len();
if registered_workers >= min_initial_workers {
return Ok(());
}
tokio::select! {
_ = cancel_token.cancelled() => {
anyhow::bail!(
"shutdown triggered before {} indexer workers appeared",
min_initial_workers
);
}
_ = tokio::time::sleep(Duration::from_millis(100)) => {}
}
}
}
async fn run_common(
config: &IndexerConfig,
registry: &Arc<WorkerRegistry>,
......@@ -245,6 +274,7 @@ async fn run_common(
}
}
wait_for_min_initial_workers(registry, &cancel_token).await?;
registry.signal_ready();
#[cfg(feature = "metrics")]
......
......@@ -555,6 +555,7 @@ impl ModelManager {
// -- KV Router creation --
#[allow(clippy::too_many_arguments)]
pub async fn kv_chooser_for(
&self,
endpoint: &Endpoint,
......
......@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use std::pin::Pin;
use std::time::Duration;
use crate::{
backend::{Backend, ExecutionContext},
......@@ -28,6 +29,7 @@ use crate::{
};
use anyhow::Context as _;
use dynamo_kv_router::config::min_initial_workers_from_env;
use dynamo_runtime::{
DistributedRuntime,
component::Client,
......@@ -47,6 +49,45 @@ pub struct PreparedEngine {
pub request_template: Option<RequestTemplate>,
}
/// Waits until `client` reports at least `min_initial_workers` available instances.
///
/// * `0` — returns immediately.
/// * `1` — delegates to `Client::wait_for_instances`.
/// * `n > 1` — watches `Client::instance_avail_watcher` until the observed
///   instance list has at least `n` entries.
///
/// The 120-second timeout applies to each individual wait for a watcher change,
/// so it restarts every time the instance list changes; it is not a bound on the
/// total time spent in this function.
///
/// # Errors
/// * Propagates the error from `wait_for_instances` in the single-worker case.
/// * Fails when no watcher change arrives within 120 seconds.
/// * Fails when the watcher is closed before enough workers appear.
async fn wait_for_min_initial_workers(
    client: &Client,
    min_initial_workers: usize,
) -> anyhow::Result<()> {
    match min_initial_workers {
        0 => return Ok(()),
        1 => {
            client.wait_for_instances().await?;
            return Ok(());
        }
        _ => {}
    }

    let mut instances = client.instance_avail_watcher();
    loop {
        // Evaluate into a plain bool so the watch borrow ends before awaiting.
        let enough = instances.borrow_and_update().len() >= min_initial_workers;
        if enough {
            return Ok(());
        }

        let next_change = tokio::time::timeout(Duration::from_secs(120), instances.changed());
        match next_change.await {
            Ok(Ok(_)) => {}
            Ok(Err(_)) => anyhow::bail!(
                "instance watcher closed before {} workers appeared for endpoint {}",
                min_initial_workers,
                client.endpoint.id()
            ),
            Err(_) => anyhow::bail!(
                "timed out waiting for {} initial workers for endpoint {}",
                min_initial_workers,
                client.endpoint.id()
            ),
        }
    }
}
impl PreparedEngine {
pub fn has_tokenizer(&self) -> bool {
if let Some(card) = self.card.as_ref() {
......@@ -254,6 +295,7 @@ where
let preprocessor_op = preprocessor.into_operator();
let backend = Backend::from_tokenizer(tokenizer).into_operator();
let migration = Migration::from_mdc(card, migration_limit, metrics).into_operator();
let min_initial_workers = min_initial_workers_from_env()?;
// For KV routing, use the client from the chooser to ensure shared state
let router_client = if router_mode == RouterMode::KV {
......@@ -265,6 +307,8 @@ where
client.clone()
};
wait_for_min_initial_workers(&router_client, min_initial_workers).await?;
// Get threshold value and wrap monitor for PushRouter
// Note: PushRouter uses active_decode_blocks_threshold for its internal logic
let threshold_value = worker_monitor
......
......@@ -5,7 +5,7 @@ use std::time::Instant;
use anyhow::Result;
use dynamo_kv_router::{
config::{KvRouterConfig, RouterConfigOverride},
config::{KvRouterConfig, RouterConfigOverride, min_initial_workers_from_env},
indexer::KvRouterError,
protocols::KV_EVENT_SUBJECT,
protocols::{
......@@ -126,7 +126,7 @@ where
pub async fn new(
endpoint: Endpoint,
client: Client,
mut workers_with_configs: RuntimeConfigWatch,
workers_with_configs: RuntimeConfigWatch,
block_size: u32,
selector: Sel,
kv_router_config: Option<KvRouterConfig>,
......@@ -138,17 +138,19 @@ where
kv_router_config.validate()?;
let component = endpoint.component();
let cancellation_token = component.drt().primary_token();
let min_initial_workers = min_initial_workers_from_env()?;
let indexer = Indexer::new(component, &kv_router_config, block_size, model_name).await?;
if !kv_router_config.skip_initial_worker_wait {
let _ = workers_with_configs
.wait_for(|m| m.len() >= kv_router_config.min_initial_workers)
if min_initial_workers > 0 && !kv_router_config.skip_initial_worker_wait {
let mut startup_watch = workers_with_configs.clone();
let _ = startup_watch
.wait_for(|m| m.len() >= min_initial_workers)
.await
.map_err(|_| {
anyhow::anyhow!(
"runtime config watch closed before {} workers appeared",
kv_router_config.min_initial_workers
min_initial_workers
)
})?;
}
......
......@@ -217,44 +217,38 @@ mod tests {
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
seq_manager_1
.add_request(SequenceRequest {
request_id: "request_0".to_string(),
token_sequence: Some(vec![0, 1, 2]),
isl: 12,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
worker: WorkerWithDpRank::new(0, 0),
lora_name: None,
})
.await?;
seq_manager_1
.add_request(SequenceRequest {
request_id: "request_1".to_string(),
token_sequence: Some(vec![3, 4]),
isl: 8,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
worker: WorkerWithDpRank::new(0, 1),
lora_name: None,
})
.await?;
seq_manager_2
.add_request(SequenceRequest {
request_id: "request_2".to_string(),
token_sequence: Some(vec![0, 1, 2, 3]),
isl: 16,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
worker: WorkerWithDpRank::new(1, 0),
lora_name: None,
})
.await?;
seq_manager_1.add_request(SequenceRequest {
request_id: "request_0".to_string(),
token_sequence: Some(vec![0, 1, 2]),
isl: 12,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
worker: WorkerWithDpRank::new(0, 0),
lora_name: None,
})?;
seq_manager_1.add_request(SequenceRequest {
request_id: "request_1".to_string(),
token_sequence: Some(vec![3, 4]),
isl: 8,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
worker: WorkerWithDpRank::new(0, 1),
lora_name: None,
})?;
seq_manager_2.add_request(SequenceRequest {
request_id: "request_2".to_string(),
token_sequence: Some(vec![0, 1, 2, 3]),
isl: 16,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
worker: WorkerWithDpRank::new(1, 0),
lora_name: None,
})?;
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
......@@ -290,10 +284,10 @@ mod tests {
"Worker 1 dp_rank 0 should have 16 active tokens (from request_2 added by seq_manager_2)"
);
seq_manager_1.free(&"request_2".to_string()).await?;
seq_manager_1.free(&"request_2".to_string())?;
seq_manager_2.free(&"request_0".to_string()).await?;
seq_manager_2.free(&"request_1".to_string()).await?;
seq_manager_2.free(&"request_0".to_string())?;
seq_manager_2.free(&"request_1".to_string())?;
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
......@@ -370,44 +364,38 @@ mod tests {
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
seq_manager_1
.add_request(SequenceRequest {
request_id: "request_0".to_string(),
token_sequence: None,
isl: 12,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
worker: WorkerWithDpRank::from_worker_id(0),
lora_name: None,
})
.await?;
seq_manager_1
.add_request(SequenceRequest {
request_id: "request_1".to_string(),
token_sequence: None,
isl: 8,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
worker: WorkerWithDpRank::from_worker_id(1),
lora_name: None,
})
.await?;
seq_manager_2
.add_request(SequenceRequest {
request_id: "request_2".to_string(),
token_sequence: None,
isl: 16,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
worker: WorkerWithDpRank::from_worker_id(2),
lora_name: None,
})
.await?;
seq_manager_1.add_request(SequenceRequest {
request_id: "request_0".to_string(),
token_sequence: None,
isl: 12,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
worker: WorkerWithDpRank::from_worker_id(0),
lora_name: None,
})?;
seq_manager_1.add_request(SequenceRequest {
request_id: "request_1".to_string(),
token_sequence: None,
isl: 8,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
worker: WorkerWithDpRank::from_worker_id(1),
lora_name: None,
})?;
seq_manager_2.add_request(SequenceRequest {
request_id: "request_2".to_string(),
token_sequence: None,
isl: 16,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
worker: WorkerWithDpRank::from_worker_id(2),
lora_name: None,
})?;
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
......@@ -430,19 +418,13 @@ mod tests {
"Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)"
);
seq_manager_1
.mark_prefill_completed(&"request_2".to_string())
.await?;
seq_manager_1.free(&"request_2".to_string()).await?;
seq_manager_2
.mark_prefill_completed(&"request_0".to_string())
.await?;
seq_manager_2
.mark_prefill_completed(&"request_1".to_string())
.await?;
seq_manager_2.free(&"request_0".to_string()).await?;
seq_manager_2.free(&"request_1".to_string()).await?;
seq_manager_1.mark_prefill_completed(&"request_2".to_string())?;
seq_manager_1.free(&"request_2".to_string())?;
seq_manager_2.mark_prefill_completed(&"request_0".to_string())?;
seq_manager_2.mark_prefill_completed(&"request_1".to_string())?;
seq_manager_2.free(&"request_0".to_string())?;
seq_manager_2.free(&"request_1".to_string())?;
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
......
......@@ -396,7 +396,7 @@ impl MockEngine {
let mut senders = Vec::with_capacity(args.dp_size as usize);
for dp_rank in 0..args.dp_size {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let (kv_event_publishers, relay_publisher): (
KvEventPublishers,
......@@ -499,12 +499,14 @@ impl MockEngine {
loop {
tokio::select! {
signal_result = output_rx.recv() => {
let Some(signal) = signal_result else {
let Some(output_batch) = signal_result else {
break; // Channel closed
};
if let Some(request_tx) = active_requests_clone.get(&signal.uuid) {
let _ = request_tx.send(signal);
for signal in output_batch {
if let Some(request_tx) = active_requests_clone.get(&signal.uuid) {
let _ = request_tx.send(signal);
}
}
}
_ = cancel_token_cloned.cancelled() => {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment