Unverified Commit a2154ba5 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore(mocker): batch live output signal sends (#7647)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 06f17011
...@@ -35,7 +35,6 @@ _KV_ROUTER_FIELDS: tuple[str, ...] = ( ...@@ -35,7 +35,6 @@ _KV_ROUTER_FIELDS: tuple[str, ...] = (
"router_queue_threshold", "router_queue_threshold",
"router_event_threads", "router_event_threads",
"router_enable_cache_control", "router_enable_cache_control",
"min_initial_workers",
"router_queue_policy", "router_queue_policy",
"remote_indexer_component", "remote_indexer_component",
) )
...@@ -61,7 +60,6 @@ class KvRouterConfigBase(ConfigBase): ...@@ -61,7 +60,6 @@ class KvRouterConfigBase(ConfigBase):
router_queue_threshold: Optional[float] router_queue_threshold: Optional[float]
router_event_threads: int router_event_threads: int
router_enable_cache_control: bool router_enable_cache_control: bool
min_initial_workers: int
router_queue_policy: str router_queue_policy: str
remote_indexer_component: Optional[str] remote_indexer_component: Optional[str]
...@@ -274,18 +272,6 @@ class KvRouterArgGroup(ArgGroup): ...@@ -274,18 +272,6 @@ class KvRouterArgGroup(ArgGroup):
"requests with nvext.cache_control." "requests with nvext.cache_control."
), ),
) )
add_argument(
g,
flag_name="--router-min-initial-workers",
env_var="DYN_ROUTER_MIN_INITIAL_WORKERS",
default=1,
help=(
"KV Router: Minimum number of workers that must be discovered before "
"router startup continues. Ignored when skip_initial_worker_wait is enabled."
),
arg_type=int,
dest="min_initial_workers",
)
add_argument( add_argument(
g, g,
flag_name="--router-queue-policy", flag_name="--router-queue-policy",
......
...@@ -50,6 +50,7 @@ class FrontendConfig(KvRouterConfigBase): ...@@ -50,6 +50,7 @@ class FrontendConfig(KvRouterConfigBase):
tls_key_path: Optional[pathlib.Path] tls_key_path: Optional[pathlib.Path]
router_mode: str router_mode: str
min_initial_workers: int
namespace: Optional[str] = None namespace: Optional[str] = None
namespace_prefix: Optional[str] = None namespace_prefix: Optional[str] = None
enforce_disagg: bool enforce_disagg: bool
...@@ -90,6 +91,8 @@ class FrontendConfig(KvRouterConfigBase): ...@@ -90,6 +91,8 @@ class FrontendConfig(KvRouterConfigBase):
raise ValueError( raise ValueError(
"--migration-limit must be between 0 and 4294967295 (0=disabled)" "--migration-limit must be between 0 and 4294967295 (0=disabled)"
) )
if self.min_initial_workers < 0:
raise ValueError("--router-min-initial-workers must be >= 0")
if self.router_enable_cache_control and self.router_mode != "kv": if self.router_enable_cache_control and self.router_mode != "kv":
raise ValueError("--enable-cache-control requires --router-mode=kv") raise ValueError("--enable-cache-control requires --router-mode=kv")
if self.tokenizer_backend not in self._VALID_TOKENIZER_BACKENDS: if self.tokenizer_backend not in self._VALID_TOKENIZER_BACKENDS:
...@@ -188,6 +191,20 @@ class FrontendArgGroup(ArgGroup): ...@@ -188,6 +191,20 @@ class FrontendArgGroup(ArgGroup):
"synchronous prefill path.", "synchronous prefill path.",
choices=["round-robin", "random", "power-of-two", "kv", "direct"], choices=["round-robin", "random", "power-of-two", "kv", "direct"],
) )
add_argument(
g,
flag_name="--router-min-initial-workers",
env_var="DYN_ROUTER_MIN_INITIAL_WORKERS",
default=0,
help=(
"Minimum number of workers required before router startup continues. "
"This is exported as DYN_ROUTER_MIN_INITIAL_WORKERS so the generic "
"push-router path and the KV router's config-ready worker gate share "
"the same startup threshold. Set to 0 to disable the startup wait."
),
arg_type=int,
dest="min_initial_workers",
)
# KV router options (shared with dynamo.router) # KV router options (shared with dynamo.router)
KvRouterArgGroup().add_arguments(parser) KvRouterArgGroup().add_arguments(parser)
......
...@@ -47,6 +47,8 @@ if TYPE_CHECKING: ...@@ -47,6 +47,8 @@ if TYPE_CHECKING:
configure_dynamo_logging() configure_dynamo_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MIN_INITIAL_WORKERS_ENV = "DYN_ROUTER_MIN_INITIAL_WORKERS"
def setup_engine_factory( def setup_engine_factory(
runtime: DistributedRuntime, runtime: DistributedRuntime,
...@@ -232,6 +234,7 @@ async def async_main(): ...@@ -232,6 +234,7 @@ async def async_main():
router_mode = RouterMode.RoundRobin router_mode = RouterMode.RoundRobin
kv_router_config = None kv_router_config = None
os.environ[MIN_INITIAL_WORKERS_ENV] = str(config.min_initial_workers)
router_config = RouterConfig( router_config = RouterConfig(
router_mode, router_mode,
kv_router_config, kv_router_config,
......
...@@ -60,6 +60,9 @@ async def main(): ...@@ -60,6 +60,9 @@ async def main():
kv_router_config=kv_router_config kv_router_config=kv_router_config
) )
# Optional startup gate shared with the frontend and standalone indexer:
# os.environ["DYN_ROUTER_MIN_INITIAL_WORKERS"] = "2"
# Your input tokens # Your input tokens
token_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] token_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
......
...@@ -16,6 +16,10 @@ This is distinct from the [Standalone Router](../../../components/src/dynamo/rou ...@@ -16,6 +16,10 @@ This is distinct from the [Standalone Router](../../../components/src/dynamo/rou
The HTTP API follows the [Mooncake KV Indexer RFC](https://github.com/kvcache-ai/Mooncake/issues/1403) conventions. The HTTP API follows the [Mooncake KV Indexer RFC](https://github.com/kvcache-ai/Mooncake/issues/1403) conventions.
`DYN_ROUTER_MIN_INITIAL_WORKERS` is also honored here. When set to a positive integer, the
standalone indexer waits for that many workers to register before opening its startup-ready
gate, matching the frontend/router startup behavior.
## Multi-Model and Multi-Tenant Support ## Multi-Model and Multi-Tenant Support
The indexer maintains one radix tree per `(model_name, tenant_id)` pair. Workers registered with different model names or tenant IDs are isolated into separate indexers — queries against one model/tenant never return scores from another. The indexer maintains one radix tree per `(model_name, tenant_id)` pair. Workers registered with different model names or tenant IDs are isolated into separate indexers — queries against one model/tenant never return scores from another.
...@@ -143,6 +147,12 @@ In runtime mode, workers are discovered automatically via MDC. The `--workers` f ...@@ -143,6 +147,12 @@ In runtime mode, workers are discovered automatically via MDC. The `--workers` f
| `--component-name` | `kv-indexer` | Component name for this indexer in the Dynamo runtime | | `--component-name` | `kv-indexer` | Component name for this indexer in the Dynamo runtime |
| `--worker-component` | `backend` | Component name that workers register under for event-plane subscription | | `--worker-component` | `backend` | Component name that workers register under for event-plane subscription |
### Shared Startup Gate
Set `DYN_ROUTER_MIN_INITIAL_WORKERS=<n>` to require at least `<n>` workers before the
standalone indexer, frontend push-router path, and KV router config-ready gate all proceed.
Leave it unset or set it to `0` to disable the startup wait.
## HTTP API ## HTTP API
### `GET /health` — Liveness check ### `GET /health` — Liveness check
......
...@@ -378,24 +378,22 @@ async fn apply_entry( ...@@ -378,24 +378,22 @@ async fn apply_entry(
isl, isl,
OverlapScores::default(), OverlapScores::default(),
); );
let _ = multi let _ = multi.add_request(SequenceRequest {
.add_request(SequenceRequest { request_id,
request_id, token_sequence: Some(block_hashes),
token_sequence: Some(block_hashes), isl,
isl, overlap: 0,
overlap: 0, track_prefill_tokens: true,
track_prefill_tokens: true, expected_output_tokens: Some(output_length as u32),
expected_output_tokens: Some(output_length as u32), worker,
worker, lora_name: None,
lora_name: None, });
})
.await;
} }
SequenceTraceEntry::PrefillComplete { request_id } => { SequenceTraceEntry::PrefillComplete { request_id } => {
let _ = multi.mark_prefill_completed(&request_id).await; let _ = multi.mark_prefill_completed(&request_id);
} }
SequenceTraceEntry::Free { request_id } => { SequenceTraceEntry::Free { request_id } => {
let _ = multi.free(&request_id).await; let _ = multi.free(&request_id);
} }
} }
} }
......
...@@ -389,7 +389,7 @@ async fn replay_worker_trace( ...@@ -389,7 +389,7 @@ async fn replay_worker_trace(
.rescale_ready_span(trace_simulation_duration_ms)? .rescale_ready_span(trace_simulation_duration_ms)?
.into_trace_driver()?; .into_trace_driver()?;
let collector = EventCollector::new(); let collector = EventCollector::new();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>(); let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let scheduler = Scheduler::new( let scheduler = Scheduler::new(
sched_args, sched_args,
0, 0,
...@@ -429,32 +429,40 @@ async fn replay_worker_trace( ...@@ -429,32 +429,40 @@ async fn replay_worker_trace(
let deadline = start + Duration::from_secs_f64((next_ready_ms.max(0.0)) / 1000.0); let deadline = start + Duration::from_secs_f64((next_ready_ms.max(0.0)) / 1000.0);
tokio::select! { tokio::select! {
maybe_signal = output_rx.recv() => { maybe_signal = output_rx.recv() => {
let Some(signal) = maybe_signal else { let Some(output_batch) = maybe_signal else {
anyhow::bail!("scheduler ended before workload replay drained"); anyhow::bail!("scheduler ended before workload replay drained");
}; };
output_signals.push(TimedOutputSignal { let timestamp_us = start.elapsed().as_micros() as u64;
signal: signal.clone(), let completion_ms = start.elapsed().as_secs_f64() * 1000.0;
timestamp_us: start.elapsed().as_micros() as u64, for signal in output_batch {
}); output_signals.push(TimedOutputSignal {
if signal.completed { signal: signal.clone(),
completed_turns += 1; timestamp_us,
driver.on_complete(signal.uuid, start.elapsed().as_secs_f64() * 1000.0)?; });
if signal.completed {
completed_turns += 1;
driver.on_complete(signal.uuid, completion_ms)?;
}
} }
} }
_ = tokio::time::sleep_until(deadline) => {} _ = tokio::time::sleep_until(deadline) => {}
} }
} }
None => { None => {
let Some(signal) = output_rx.recv().await else { let Some(output_batch) = output_rx.recv().await else {
anyhow::bail!("scheduler ended before workload replay drained"); anyhow::bail!("scheduler ended before workload replay drained");
}; };
output_signals.push(TimedOutputSignal { let timestamp_us = start.elapsed().as_micros() as u64;
signal: signal.clone(), let completion_ms = start.elapsed().as_secs_f64() * 1000.0;
timestamp_us: start.elapsed().as_micros() as u64, for signal in output_batch {
}); output_signals.push(TimedOutputSignal {
if signal.completed { signal: signal.clone(),
completed_turns += 1; timestamp_us,
driver.on_complete(signal.uuid, start.elapsed().as_secs_f64() * 1000.0)?; });
if signal.completed {
completed_turns += 1;
driver.on_complete(signal.uuid, completion_ms)?;
}
} }
} }
} }
......
...@@ -58,7 +58,7 @@ impl KvRouterConfig { ...@@ -58,7 +58,7 @@ impl KvRouterConfig {
#[pymethods] #[pymethods]
impl KvRouterConfig { impl KvRouterConfig {
#[new] #[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_track_prefill_tokens=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(4.0), router_event_threads=4, router_enable_cache_control=false, min_initial_workers=1, router_queue_policy="fcfs", remote_indexer_component=None))] #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_track_prefill_tokens=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(4.0), router_event_threads=4, router_enable_cache_control=false, router_queue_policy="fcfs", remote_indexer_component=None))]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn new( fn new(
overlap_score_weight: f64, overlap_score_weight: f64,
...@@ -78,7 +78,6 @@ impl KvRouterConfig { ...@@ -78,7 +78,6 @@ impl KvRouterConfig {
router_queue_threshold: Option<f64>, router_queue_threshold: Option<f64>,
router_event_threads: u32, router_event_threads: u32,
router_enable_cache_control: bool, router_enable_cache_control: bool,
min_initial_workers: usize,
router_queue_policy: &str, router_queue_policy: &str,
remote_indexer_component: Option<String>, remote_indexer_component: Option<String>,
) -> Self { ) -> Self {
...@@ -102,7 +101,6 @@ impl KvRouterConfig { ...@@ -102,7 +101,6 @@ impl KvRouterConfig {
router_event_threads, router_event_threads,
router_enable_cache_control, router_enable_cache_control,
skip_initial_worker_wait: false, skip_initial_worker_wait: false,
min_initial_workers,
router_queue_policy: router_queue_policy.parse().unwrap_or_else(|_| { router_queue_policy: router_queue_policy.parse().unwrap_or_else(|_| {
panic!("invalid router_queue_policy: {router_queue_policy:?}") panic!("invalid router_queue_policy: {router_queue_policy:?}")
}), }),
......
...@@ -1179,7 +1179,6 @@ class KvRouterConfig: ...@@ -1179,7 +1179,6 @@ class KvRouterConfig:
router_queue_threshold: Optional[float] = 4.0, router_queue_threshold: Optional[float] = 4.0,
router_event_threads: int = 4, router_event_threads: int = 4,
router_enable_cache_control: bool = False, router_enable_cache_control: bool = False,
min_initial_workers: int = 1,
router_queue_policy: str = "fcfs", router_queue_policy: str = "fcfs",
) -> None: ) -> None:
""" """
...@@ -1213,9 +1212,6 @@ class KvRouterConfig: ...@@ -1213,9 +1212,6 @@ class KvRouterConfig:
When > 1, uses a concurrent radix tree with a thread pool. When > 1, uses a concurrent radix tree with a thread pool.
router_enable_cache_control: Enable cache control (PIN with TTL) via the worker's router_enable_cache_control: Enable cache control (PIN with TTL) via the worker's
cache_control service mesh endpoint (default: False). cache_control service mesh endpoint (default: False).
min_initial_workers: Minimum number of discovered workers required before
router startup continues (default: 1). Ignored when
skip_initial_worker_wait is enabled.
router_queue_policy: Scheduling policy for the router queue (default: "fcfs"). router_queue_policy: Scheduling policy for the router queue (default: "fcfs").
"fcfs": first-come first-served with priority bumps — optimizes tail TTFT. "fcfs": first-come first-served with priority bumps — optimizes tail TTFT.
"lcfs": last-come first-served with priority bumps — intentionally worsens tail behavior for policy comparisons. "lcfs": last-come first-served with priority bumps — intentionally worsens tail behavior for policy comparisons.
......
...@@ -84,7 +84,6 @@ def _router_config_payload(): ...@@ -84,7 +84,6 @@ def _router_config_payload():
"router_prune_target_ratio": 0.8, "router_prune_target_ratio": 0.8,
"router_enable_cache_control": False, "router_enable_cache_control": False,
"skip_initial_worker_wait": False, "skip_initial_worker_wait": False,
"min_initial_workers": 1,
"remote_indexer_component": None, "remote_indexer_component": None,
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::env::{self, VarError};
use std::fmt; use std::fmt;
use std::str::FromStr; use std::str::FromStr;
...@@ -13,14 +14,26 @@ use crate::protocols::{ ...@@ -13,14 +14,26 @@ use crate::protocols::{
BlockHashOptions, LocalBlockHash, compute_block_hash_for_seq, compute_seq_hash_for_block, BlockHashOptions, LocalBlockHash, compute_block_hash_for_seq, compute_seq_hash_for_block,
}; };
const fn default_min_initial_workers() -> usize {
1
}
const fn default_track_prefill_tokens() -> bool { const fn default_track_prefill_tokens() -> bool {
true true
} }
pub const DYN_ROUTER_MIN_INITIAL_WORKERS: &str = "DYN_ROUTER_MIN_INITIAL_WORKERS";
pub fn min_initial_workers_from_env() -> anyhow::Result<usize> {
match env::var(DYN_ROUTER_MIN_INITIAL_WORKERS) {
Ok(value) => value.parse::<usize>().map_err(|error| {
anyhow::anyhow!(
"{DYN_ROUTER_MIN_INITIAL_WORKERS} must be a non-negative integer, got {value:?}: {error}"
)
}),
Err(VarError::NotPresent) => Ok(0),
Err(VarError::NotUnicode(_)) => {
anyhow::bail!("{DYN_ROUTER_MIN_INITIAL_WORKERS} must be valid unicode")
}
}
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum RouterQueuePolicy { pub enum RouterQueuePolicy {
...@@ -149,17 +162,8 @@ pub struct KvRouterConfig { ...@@ -149,17 +162,8 @@ pub struct KvRouterConfig {
/// When false (default), cache_control is ignored and no cache_control client is created. /// When false (default), cache_control is ignored and no cache_control client is created.
pub router_enable_cache_control: bool, pub router_enable_cache_control: bool,
/// Skip blocking for workers at init time (default: false).
/// When true, the router starts immediately without waiting for discovery-based
/// workers and workers are provided externally per-request (e.g., EPP).
pub skip_initial_worker_wait: bool, pub skip_initial_worker_wait: bool,
/// Minimum number of workers that must be discovered before router startup continues.
/// Default: 1. Ignored when skip_initial_worker_wait=true.
#[serde(default = "default_min_initial_workers")]
#[validate(range(min = 1))]
pub min_initial_workers: usize,
/// Scheduling policy for the router queue. /// Scheduling policy for the router queue.
/// "fcfs" (default): first-come first-served with priority bumps — optimizes tail TTFT. /// "fcfs" (default): first-come first-served with priority bumps — optimizes tail TTFT.
/// "wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT. /// "wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT.
...@@ -194,7 +198,6 @@ impl Default for KvRouterConfig { ...@@ -194,7 +198,6 @@ impl Default for KvRouterConfig {
router_event_threads: 4, router_event_threads: 4,
router_enable_cache_control: false, router_enable_cache_control: false,
skip_initial_worker_wait: false, skip_initial_worker_wait: false,
min_initial_workers: default_min_initial_workers(),
router_queue_policy: RouterQueuePolicy::default(), router_queue_policy: RouterQueuePolicy::default(),
remote_indexer_component: None, remote_indexer_component: None,
} }
...@@ -309,25 +312,11 @@ mod tests { ...@@ -309,25 +312,11 @@ mod tests {
assert_eq!(deserialized, RouterQueuePolicy::Lcfs); assert_eq!(deserialized, RouterQueuePolicy::Lcfs);
} }
#[test]
fn kv_router_config_defaults_to_one_initial_worker() {
assert_eq!(KvRouterConfig::default().min_initial_workers, 1);
}
#[test] #[test]
fn kv_router_config_defaults_to_tracking_prefill_tokens() { fn kv_router_config_defaults_to_tracking_prefill_tokens() {
assert!(KvRouterConfig::default().router_track_prefill_tokens); assert!(KvRouterConfig::default().router_track_prefill_tokens);
} }
#[test]
fn kv_router_config_rejects_zero_initial_workers() {
let cfg = KvRouterConfig {
min_initial_workers: 0,
..KvRouterConfig::default()
};
assert!(cfg.validate().is_err());
}
#[test] #[test]
fn compute_seq_hashes_for_tracking_uses_mm_hashes() { fn compute_seq_hashes_for_tracking_uses_mm_hashes() {
let cfg = KvRouterConfig::default(); let cfg = KvRouterConfig::default();
......
...@@ -192,19 +192,17 @@ where ...@@ -192,19 +192,17 @@ where
} }
pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> { pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
self.slots.add_request(req).await self.slots.add_request(req)
} }
pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> { pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
self.slots self.slots.mark_prefill_completed(&request_id.to_string())?;
.mark_prefill_completed(&request_id.to_string())
.await?;
self.queue.update().await; self.queue.update().await;
Ok(()) Ok(())
} }
pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> { pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
self.slots.free(&request_id.to_string()).await?; self.slots.free(&request_id.to_string())?;
self.queue.update().await; self.queue.update().await;
Ok(()) Ok(())
} }
......
...@@ -231,20 +231,16 @@ impl< ...@@ -231,20 +231,16 @@ impl<
return; return;
}; };
if let Err(e) = self if let Err(e) = self.slots.add_request(SequenceRequest {
.slots request_id: request_id.clone(),
.add_request(SequenceRequest { token_sequence: request.token_seq,
request_id: request_id.clone(), isl: request.isl_tokens,
token_sequence: request.token_seq, overlap: selection.overlap_blocks,
isl: request.isl_tokens, track_prefill_tokens: request.track_prefill_tokens,
overlap: selection.overlap_blocks, expected_output_tokens: request.expected_output_tokens,
track_prefill_tokens: request.track_prefill_tokens, worker: selection.worker,
expected_output_tokens: request.expected_output_tokens, lora_name: request.lora_name.clone(),
worker: selection.worker, }) {
lora_name: request.lora_name.clone(),
})
.await
{
tracing::warn!("Failed to add request {request_id}: {e}"); tracing::warn!("Failed to add request {request_id}: {e}");
} }
} }
...@@ -413,8 +409,8 @@ mod tests { ...@@ -413,8 +409,8 @@ mod tests {
let resp = resp.expect("scheduling failed"); let resp = resp.expect("scheduling failed");
assert!(resp.best_worker.worker_id < num_workers as u64); assert!(resp.best_worker.worker_id < num_workers as u64);
slots.mark_prefill_completed(&req_id).await.unwrap(); slots.mark_prefill_completed(&req_id).unwrap();
slots.free(&req_id).await.unwrap(); slots.free(&req_id).unwrap();
queue.update().await; queue.update().await;
})); }));
} }
...@@ -457,8 +453,8 @@ mod tests { ...@@ -457,8 +453,8 @@ mod tests {
for _ in 0..num_requests { for _ in 0..num_requests {
queue.update().await; queue.update().await;
for rid in &req_ids { for rid in &req_ids {
let _ = slots.mark_prefill_completed(rid).await; let _ = slots.mark_prefill_completed(rid);
let _ = slots.free(rid).await; let _ = slots.free(rid);
} }
} }
queue.update().await; queue.update().await;
...@@ -499,11 +495,8 @@ mod tests { ...@@ -499,11 +495,8 @@ mod tests {
assert_eq!(queue.pending_count(), 2); assert_eq!(queue.pending_count(), 2);
// Free the first request and update — should drain one from pending // Free the first request and update — should drain one from pending
slots slots.mark_prefill_completed(&"req-1".to_string()).unwrap();
.mark_prefill_completed(&"req-1".to_string()) slots.free(&"req-1".to_string()).unwrap();
.await
.unwrap();
slots.free(&"req-1".to_string()).await.unwrap();
queue.update().await; queue.update().await;
// After update, one pending request should have been scheduled // After update, one pending request should have been scheduled
...@@ -514,11 +507,11 @@ mod tests { ...@@ -514,11 +507,11 @@ mod tests {
); );
// Free req-2 and update to drain remaining // Free req-2 and update to drain remaining
let _ = slots.mark_prefill_completed(&"req-2".to_string()).await; let _ = slots.mark_prefill_completed(&"req-2".to_string());
let _ = slots.free(&"req-2".to_string()).await; let _ = slots.free(&"req-2".to_string());
queue.update().await; queue.update().await;
let _ = slots.mark_prefill_completed(&"req-3".to_string()).await; let _ = slots.mark_prefill_completed(&"req-3".to_string());
let _ = slots.free(&"req-3".to_string()).await; let _ = slots.free(&"req-3".to_string());
queue.update().await; queue.update().await;
assert_eq!(queue.pending_count(), 0, "all requests should be drained"); assert_eq!(queue.pending_count(), 0, "all requests should be drained");
...@@ -598,9 +591,8 @@ mod tests { ...@@ -598,9 +591,8 @@ mod tests {
// Clean up // Clean up
slots slots
.mark_prefill_completed(&"after-register".to_string()) .mark_prefill_completed(&"after-register".to_string())
.await
.unwrap(); .unwrap();
slots.free(&"after-register".to_string()).await.unwrap(); slots.free(&"after-register".to_string()).unwrap();
} }
/// Register_workers is additive: calling with a new set does NOT remove old workers. /// Register_workers is additive: calling with a new set does NOT remove old workers.
...@@ -651,8 +643,8 @@ mod tests { ...@@ -651,8 +643,8 @@ mod tests {
.expect("oneshot dropped") .expect("oneshot dropped")
.expect("scheduling failed"); .expect("scheduling failed");
seen.insert(resp.best_worker.worker_id); seen.insert(resp.best_worker.worker_id);
slots.mark_prefill_completed(&req_id).await.unwrap(); slots.mark_prefill_completed(&req_id).unwrap();
slots.free(&req_id).await.unwrap(); slots.free(&req_id).unwrap();
} }
assert!( assert!(
...@@ -721,9 +713,8 @@ mod tests { ...@@ -721,9 +713,8 @@ mod tests {
); );
slots slots
.mark_prefill_completed(&"filter-0".to_string()) .mark_prefill_completed(&"filter-0".to_string())
.await
.unwrap(); .unwrap();
slots.free(&"filter-0".to_string()).await.unwrap(); slots.free(&"filter-0".to_string()).unwrap();
} }
#[tokio::test(flavor = "multi_thread")] #[tokio::test(flavor = "multi_thread")]
...@@ -747,9 +738,9 @@ mod tests { ...@@ -747,9 +738,9 @@ mod tests {
let _resp2 = rx2.await.unwrap().unwrap(); let _resp2 = rx2.await.unwrap().unwrap();
assert_eq!(queue.pending_count(), 0); assert_eq!(queue.pending_count(), 0);
let _ = slots.mark_prefill_completed(&"req-1".to_string()).await; let _ = slots.mark_prefill_completed(&"req-1".to_string());
let _ = slots.free(&"req-1".to_string()).await; let _ = slots.free(&"req-1".to_string());
let _ = slots.mark_prefill_completed(&"req-2".to_string()).await; let _ = slots.mark_prefill_completed(&"req-2".to_string());
let _ = slots.free(&"req-2".to_string()).await; let _ = slots.free(&"req-2".to_string());
} }
} }
...@@ -83,12 +83,6 @@ pub enum SequenceError { ...@@ -83,12 +83,6 @@ pub enum SequenceError {
#[error("Request {request_id} not found")] #[error("Request {request_id} not found")]
RequestNotFound { request_id: String }, RequestNotFound { request_id: String },
#[error("Failed to publish event: {0}")]
PublishFailed(#[from] anyhow::Error),
#[error("Synchronous mutation requires replica_sync=false")]
SyncMutationRequiresNoReplicaSync,
} }
/// Bundled parameters for adding a request to the sequence tracker. /// Bundled parameters for adding a request to the sequence tracker.
...@@ -147,7 +141,7 @@ pub struct ActiveSequencesMultiWorker<P: SequencePublisher> { ...@@ -147,7 +141,7 @@ pub struct ActiveSequencesMultiWorker<P: SequencePublisher> {
request_to_lora: DashMap<RequestId, String>, request_to_lora: DashMap<RequestId, String>,
block_size: usize, block_size: usize,
router_id: u64, router_id: u64,
publisher: P, publisher: Arc<P>,
replica_sync: bool, replica_sync: bool,
worker_type: &'static str, worker_type: &'static str,
} }
...@@ -172,12 +166,29 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -172,12 +166,29 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
request_to_lora: DashMap::new(), request_to_lora: DashMap::new(),
block_size, block_size,
router_id, router_id,
publisher, publisher: Arc::new(publisher),
replica_sync, replica_sync,
worker_type, worker_type,
} }
} }
fn spawn_publish_event(&self, event: ActiveSequenceEvent) {
if !self.replica_sync {
return;
}
let publisher = Arc::clone(&self.publisher);
tokio::spawn(async move {
if let Err(e) = publisher.publish_event(&event).await {
tracing::error!(
request_id = %event.request_id,
worker = ?event.worker,
"failed to publish active sequence event: {e}"
);
}
});
}
/// Spawn a background task that subscribes to replica-sync events from peer routers /// Spawn a background task that subscribes to replica-sync events from peer routers
/// and applies them to the local state. /// and applies them to the local state.
pub fn start_replica_sync<S: SequenceSubscriber + 'static>( pub fn start_replica_sync<S: SequenceSubscriber + 'static>(
...@@ -370,13 +381,6 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -370,13 +381,6 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
} }
} }
fn ensure_sync_mutation_allowed(&self) -> Result<(), SequenceError> {
if self.replica_sync {
return Err(SequenceError::SyncMutationRequiresNoReplicaSync);
}
Ok(())
}
fn add_request_local(&self, req: SequenceRequest) -> Result<(), SequenceError> { fn add_request_local(&self, req: SequenceRequest) -> Result<(), SequenceError> {
let SequenceRequest { let SequenceRequest {
request_id, request_id,
...@@ -433,29 +437,20 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -433,29 +437,20 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
Ok(()) Ok(())
} }
pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> { pub fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
if self.replica_sync { self.spawn_publish_event(ActiveSequenceEvent {
let event = ActiveSequenceEvent { request_id: req.request_id.clone(),
request_id: req.request_id.clone(), worker: req.worker,
worker: req.worker, data: ActiveSequenceEventData::AddRequest {
data: ActiveSequenceEventData::AddRequest { token_sequence: req.token_sequence.clone(),
token_sequence: req.token_sequence.clone(), isl: req.isl,
isl: req.isl, overlap: req.overlap,
overlap: req.overlap, track_prefill_tokens: req.track_prefill_tokens,
track_prefill_tokens: req.track_prefill_tokens, expected_output_tokens: req.expected_output_tokens,
expected_output_tokens: req.expected_output_tokens, },
}, router_id: self.router_id,
router_id: self.router_id, lora_name: req.lora_name.clone(),
lora_name: req.lora_name.clone(), });
};
self.publisher.publish_event(&event).await?;
}
self.add_request_local(req)
}
pub fn add_request_sync(&self, req: SequenceRequest) -> Result<(), SequenceError> {
self.ensure_sync_mutation_allowed()?;
self.add_request_local(req) self.add_request_local(req)
} }
...@@ -495,7 +490,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -495,7 +490,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
Ok(()) Ok(())
} }
async fn mutate_request_worker( fn mutate_request_worker(
&self, &self,
request_id: &RequestId, request_id: &RequestId,
event_data: ActiveSequenceEventData, event_data: ActiveSequenceEventData,
...@@ -510,21 +505,17 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -510,21 +505,17 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
request_id: request_id.clone(), request_id: request_id.clone(),
})?; })?;
if self.replica_sync { let lora_name = self
let lora_name = self .request_to_lora
.request_to_lora .get(request_id)
.get(request_id) .map(|entry| entry.value().clone());
.map(|entry| entry.value().clone()); self.spawn_publish_event(ActiveSequenceEvent {
request_id: request_id.clone(),
let event = ActiveSequenceEvent { worker,
request_id: request_id.clone(), data: event_data,
worker, router_id: self.router_id,
data: event_data, lora_name,
router_id: self.router_id, });
lora_name,
};
self.publisher.publish_event(&event).await?;
}
self.mutate_request_worker_local(request_id, mutate_fn, remove_mapping) self.mutate_request_worker_local(request_id, mutate_fn, remove_mapping)
} }
...@@ -537,7 +528,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -537,7 +528,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
/// This also performs the underlying prefill-complete cleanup via /// This also performs the underlying prefill-complete cleanup via
/// [`ActiveSequences::free`], so callers do not need to call /// [`ActiveSequences::free`], so callers do not need to call
/// [`Self::mark_prefill_completed`] before freeing a completed request. /// [`Self::mark_prefill_completed`] before freeing a completed request.
pub async fn free(&self, request_id: &RequestId) -> Result<(), SequenceError> { pub fn free(&self, request_id: &RequestId) -> Result<(), SequenceError> {
if !self.request_to_worker.contains_key(request_id) { if !self.request_to_worker.contains_key(request_id) {
tracing::debug!("Request {request_id} not found, already freed (idempotent)"); tracing::debug!("Request {request_id} not found, already freed (idempotent)");
return Ok(()); return Ok(());
...@@ -551,32 +542,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -551,32 +542,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
}, },
true, true,
) )
.await
}
pub fn free_sync(&self, request_id: &RequestId) -> Result<(), SequenceError> {
self.ensure_sync_mutation_allowed()?;
if !self.request_to_worker.contains_key(request_id) {
tracing::debug!("Request {request_id} not found, already freed (idempotent)");
return Ok(());
}
self.mutate_request_worker_local(
request_id,
|seqs, rid| {
seqs.free(rid);
},
true,
)
} }
/// Mark prefill as completed for a request. /// Mark prefill as completed for a request.
/// ///
/// Note: Calling this multiple times for the same request is allowed and will be a no-op /// Note: Calling this multiple times for the same request is allowed and will be a no-op
/// after the first call (idempotent). /// after the first call (idempotent).
pub async fn mark_prefill_completed( pub fn mark_prefill_completed(&self, request_id: &RequestId) -> Result<(), SequenceError> {
&self,
request_id: &RequestId,
) -> Result<(), SequenceError> {
self.mutate_request_worker( self.mutate_request_worker(
request_id, request_id,
ActiveSequenceEventData::MarkPrefillCompleted, ActiveSequenceEventData::MarkPrefillCompleted,
...@@ -585,18 +557,6 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -585,18 +557,6 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
}, },
false, false,
) )
.await
}
pub fn mark_prefill_completed_sync(&self, request_id: &RequestId) -> Result<(), SequenceError> {
self.ensure_sync_mutation_allowed()?;
self.mutate_request_worker_local(
request_id,
|seqs, rid| {
seqs.mark_prefill_completed(rid);
},
false,
)
} }
/// Add an output block with optional fractional decay weight. /// Add an output block with optional fractional decay weight.
...@@ -895,7 +855,6 @@ mod tests { ...@@ -895,7 +855,6 @@ mod tests {
worker, worker,
lora_name: None, lora_name: None,
}) })
.await
.unwrap(); .unwrap();
assert_eq!(sequences.active_tokens().get(&worker).copied(), Some(0)); assert_eq!(sequences.active_tokens().get(&worker).copied(), Some(0));
......
...@@ -11,10 +11,12 @@ pub mod runtime; ...@@ -11,10 +11,12 @@ pub mod runtime;
pub mod server; pub mod server;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use crate::config::min_initial_workers_from_env;
use registry::WorkerRegistry; use registry::WorkerRegistry;
use server::{AppState, create_router}; use server::{AppState, create_router};
...@@ -198,6 +200,33 @@ pub async fn run_with_runtime( ...@@ -198,6 +200,33 @@ pub async fn run_with_runtime(
run_common(&config, &registry, cancel_token).await run_common(&config, &registry, cancel_token).await
} }
async fn wait_for_min_initial_workers(
registry: &WorkerRegistry,
cancel_token: &CancellationToken,
) -> anyhow::Result<()> {
let min_initial_workers = min_initial_workers_from_env()?;
if min_initial_workers == 0 {
return Ok(());
}
loop {
let registered_workers = registry.list().len();
if registered_workers >= min_initial_workers {
return Ok(());
}
tokio::select! {
_ = cancel_token.cancelled() => {
anyhow::bail!(
"shutdown triggered before {} indexer workers appeared",
min_initial_workers
);
}
_ = tokio::time::sleep(Duration::from_millis(100)) => {}
}
}
}
async fn run_common( async fn run_common(
config: &IndexerConfig, config: &IndexerConfig,
registry: &Arc<WorkerRegistry>, registry: &Arc<WorkerRegistry>,
...@@ -245,6 +274,7 @@ async fn run_common( ...@@ -245,6 +274,7 @@ async fn run_common(
} }
} }
wait_for_min_initial_workers(registry, &cancel_token).await?;
registry.signal_ready(); registry.signal_ready();
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
......
...@@ -555,6 +555,7 @@ impl ModelManager { ...@@ -555,6 +555,7 @@ impl ModelManager {
// -- KV Router creation -- // -- KV Router creation --
#[allow(clippy::too_many_arguments)]
pub async fn kv_chooser_for( pub async fn kv_chooser_for(
&self, &self,
endpoint: &Endpoint, endpoint: &Endpoint,
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::pin::Pin; use std::pin::Pin;
use std::time::Duration;
use crate::{ use crate::{
backend::{Backend, ExecutionContext}, backend::{Backend, ExecutionContext},
...@@ -28,6 +29,7 @@ use crate::{ ...@@ -28,6 +29,7 @@ use crate::{
}; };
use anyhow::Context as _; use anyhow::Context as _;
use dynamo_kv_router::config::min_initial_workers_from_env;
use dynamo_runtime::{ use dynamo_runtime::{
DistributedRuntime, DistributedRuntime,
component::Client, component::Client,
...@@ -47,6 +49,45 @@ pub struct PreparedEngine { ...@@ -47,6 +49,45 @@ pub struct PreparedEngine {
pub request_template: Option<RequestTemplate>, pub request_template: Option<RequestTemplate>,
} }
async fn wait_for_min_initial_workers(
client: &Client,
min_initial_workers: usize,
) -> anyhow::Result<()> {
if min_initial_workers == 0 {
return Ok(());
}
if min_initial_workers == 1 {
client.wait_for_instances().await?;
return Ok(());
}
let mut watcher = client.instance_avail_watcher();
loop {
let available = watcher.borrow_and_update().len();
if available >= min_initial_workers {
return Ok(());
}
tokio::time::timeout(Duration::from_secs(120), watcher.changed())
.await
.map_err(|_| {
anyhow::anyhow!(
"timed out waiting for {} initial workers for endpoint {}",
min_initial_workers,
client.endpoint.id()
)
})?
.map_err(|_| {
anyhow::anyhow!(
"instance watcher closed before {} workers appeared for endpoint {}",
min_initial_workers,
client.endpoint.id()
)
})?;
}
}
impl PreparedEngine { impl PreparedEngine {
pub fn has_tokenizer(&self) -> bool { pub fn has_tokenizer(&self) -> bool {
if let Some(card) = self.card.as_ref() { if let Some(card) = self.card.as_ref() {
...@@ -254,6 +295,7 @@ where ...@@ -254,6 +295,7 @@ where
let preprocessor_op = preprocessor.into_operator(); let preprocessor_op = preprocessor.into_operator();
let backend = Backend::from_tokenizer(tokenizer).into_operator(); let backend = Backend::from_tokenizer(tokenizer).into_operator();
let migration = Migration::from_mdc(card, migration_limit, metrics).into_operator(); let migration = Migration::from_mdc(card, migration_limit, metrics).into_operator();
let min_initial_workers = min_initial_workers_from_env()?;
// For KV routing, use the client from the chooser to ensure shared state // For KV routing, use the client from the chooser to ensure shared state
let router_client = if router_mode == RouterMode::KV { let router_client = if router_mode == RouterMode::KV {
...@@ -265,6 +307,8 @@ where ...@@ -265,6 +307,8 @@ where
client.clone() client.clone()
}; };
wait_for_min_initial_workers(&router_client, min_initial_workers).await?;
// Get threshold value and wrap monitor for PushRouter // Get threshold value and wrap monitor for PushRouter
// Note: PushRouter uses active_decode_blocks_threshold for its internal logic // Note: PushRouter uses active_decode_blocks_threshold for its internal logic
let threshold_value = worker_monitor let threshold_value = worker_monitor
......
...@@ -5,7 +5,7 @@ use std::time::Instant; ...@@ -5,7 +5,7 @@ use std::time::Instant;
use anyhow::Result; use anyhow::Result;
use dynamo_kv_router::{ use dynamo_kv_router::{
config::{KvRouterConfig, RouterConfigOverride}, config::{KvRouterConfig, RouterConfigOverride, min_initial_workers_from_env},
indexer::KvRouterError, indexer::KvRouterError,
protocols::KV_EVENT_SUBJECT, protocols::KV_EVENT_SUBJECT,
protocols::{ protocols::{
...@@ -126,7 +126,7 @@ where ...@@ -126,7 +126,7 @@ where
pub async fn new( pub async fn new(
endpoint: Endpoint, endpoint: Endpoint,
client: Client, client: Client,
mut workers_with_configs: RuntimeConfigWatch, workers_with_configs: RuntimeConfigWatch,
block_size: u32, block_size: u32,
selector: Sel, selector: Sel,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
...@@ -138,17 +138,19 @@ where ...@@ -138,17 +138,19 @@ where
kv_router_config.validate()?; kv_router_config.validate()?;
let component = endpoint.component(); let component = endpoint.component();
let cancellation_token = component.drt().primary_token(); let cancellation_token = component.drt().primary_token();
let min_initial_workers = min_initial_workers_from_env()?;
let indexer = Indexer::new(component, &kv_router_config, block_size, model_name).await?; let indexer = Indexer::new(component, &kv_router_config, block_size, model_name).await?;
if !kv_router_config.skip_initial_worker_wait { if min_initial_workers > 0 && !kv_router_config.skip_initial_worker_wait {
let _ = workers_with_configs let mut startup_watch = workers_with_configs.clone();
.wait_for(|m| m.len() >= kv_router_config.min_initial_workers) let _ = startup_watch
.wait_for(|m| m.len() >= min_initial_workers)
.await .await
.map_err(|_| { .map_err(|_| {
anyhow::anyhow!( anyhow::anyhow!(
"runtime config watch closed before {} workers appeared", "runtime config watch closed before {} workers appeared",
kv_router_config.min_initial_workers min_initial_workers
) )
})?; })?;
} }
......
...@@ -217,44 +217,38 @@ mod tests { ...@@ -217,44 +217,38 @@ mod tests {
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
seq_manager_1 seq_manager_1.add_request(SequenceRequest {
.add_request(SequenceRequest { request_id: "request_0".to_string(),
request_id: "request_0".to_string(), token_sequence: Some(vec![0, 1, 2]),
token_sequence: Some(vec![0, 1, 2]), isl: 12,
isl: 12, overlap: 0,
overlap: 0, track_prefill_tokens: true,
track_prefill_tokens: true, expected_output_tokens: None,
expected_output_tokens: None, worker: WorkerWithDpRank::new(0, 0),
worker: WorkerWithDpRank::new(0, 0), lora_name: None,
lora_name: None, })?;
})
.await?; seq_manager_1.add_request(SequenceRequest {
request_id: "request_1".to_string(),
seq_manager_1 token_sequence: Some(vec![3, 4]),
.add_request(SequenceRequest { isl: 8,
request_id: "request_1".to_string(), overlap: 0,
token_sequence: Some(vec![3, 4]), track_prefill_tokens: true,
isl: 8, expected_output_tokens: None,
overlap: 0, worker: WorkerWithDpRank::new(0, 1),
track_prefill_tokens: true, lora_name: None,
expected_output_tokens: None, })?;
worker: WorkerWithDpRank::new(0, 1),
lora_name: None, seq_manager_2.add_request(SequenceRequest {
}) request_id: "request_2".to_string(),
.await?; token_sequence: Some(vec![0, 1, 2, 3]),
isl: 16,
seq_manager_2 overlap: 0,
.add_request(SequenceRequest { track_prefill_tokens: true,
request_id: "request_2".to_string(), expected_output_tokens: None,
token_sequence: Some(vec![0, 1, 2, 3]), worker: WorkerWithDpRank::new(1, 0),
isl: 16, lora_name: None,
overlap: 0, })?;
track_prefill_tokens: true,
expected_output_tokens: None,
worker: WorkerWithDpRank::new(1, 0),
lora_name: None,
})
.await?;
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
...@@ -290,10 +284,10 @@ mod tests { ...@@ -290,10 +284,10 @@ mod tests {
"Worker 1 dp_rank 0 should have 16 active tokens (from request_2 added by seq_manager_2)" "Worker 1 dp_rank 0 should have 16 active tokens (from request_2 added by seq_manager_2)"
); );
seq_manager_1.free(&"request_2".to_string()).await?; seq_manager_1.free(&"request_2".to_string())?;
seq_manager_2.free(&"request_0".to_string()).await?; seq_manager_2.free(&"request_0".to_string())?;
seq_manager_2.free(&"request_1".to_string()).await?; seq_manager_2.free(&"request_1".to_string())?;
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
...@@ -370,44 +364,38 @@ mod tests { ...@@ -370,44 +364,38 @@ mod tests {
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
seq_manager_1 seq_manager_1.add_request(SequenceRequest {
.add_request(SequenceRequest { request_id: "request_0".to_string(),
request_id: "request_0".to_string(), token_sequence: None,
token_sequence: None, isl: 12,
isl: 12, overlap: 0,
overlap: 0, track_prefill_tokens: true,
track_prefill_tokens: true, expected_output_tokens: None,
expected_output_tokens: None, worker: WorkerWithDpRank::from_worker_id(0),
worker: WorkerWithDpRank::from_worker_id(0), lora_name: None,
lora_name: None, })?;
})
.await?; seq_manager_1.add_request(SequenceRequest {
request_id: "request_1".to_string(),
seq_manager_1 token_sequence: None,
.add_request(SequenceRequest { isl: 8,
request_id: "request_1".to_string(), overlap: 0,
token_sequence: None, track_prefill_tokens: true,
isl: 8, expected_output_tokens: None,
overlap: 0, worker: WorkerWithDpRank::from_worker_id(1),
track_prefill_tokens: true, lora_name: None,
expected_output_tokens: None, })?;
worker: WorkerWithDpRank::from_worker_id(1),
lora_name: None, seq_manager_2.add_request(SequenceRequest {
}) request_id: "request_2".to_string(),
.await?; token_sequence: None,
isl: 16,
seq_manager_2 overlap: 0,
.add_request(SequenceRequest { track_prefill_tokens: true,
request_id: "request_2".to_string(), expected_output_tokens: None,
token_sequence: None, worker: WorkerWithDpRank::from_worker_id(2),
isl: 16, lora_name: None,
overlap: 0, })?;
track_prefill_tokens: true,
expected_output_tokens: None,
worker: WorkerWithDpRank::from_worker_id(2),
lora_name: None,
})
.await?;
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
...@@ -430,19 +418,13 @@ mod tests { ...@@ -430,19 +418,13 @@ mod tests {
"Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)" "Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)"
); );
seq_manager_1 seq_manager_1.mark_prefill_completed(&"request_2".to_string())?;
.mark_prefill_completed(&"request_2".to_string()) seq_manager_1.free(&"request_2".to_string())?;
.await?;
seq_manager_1.free(&"request_2".to_string()).await?; seq_manager_2.mark_prefill_completed(&"request_0".to_string())?;
seq_manager_2.mark_prefill_completed(&"request_1".to_string())?;
seq_manager_2 seq_manager_2.free(&"request_0".to_string())?;
.mark_prefill_completed(&"request_0".to_string()) seq_manager_2.free(&"request_1".to_string())?;
.await?;
seq_manager_2
.mark_prefill_completed(&"request_1".to_string())
.await?;
seq_manager_2.free(&"request_0".to_string()).await?;
seq_manager_2.free(&"request_1".to_string()).await?;
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
......
...@@ -396,7 +396,7 @@ impl MockEngine { ...@@ -396,7 +396,7 @@ impl MockEngine {
let mut senders = Vec::with_capacity(args.dp_size as usize); let mut senders = Vec::with_capacity(args.dp_size as usize);
for dp_rank in 0..args.dp_size { for dp_rank in 0..args.dp_size {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>(); let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let (kv_event_publishers, relay_publisher): ( let (kv_event_publishers, relay_publisher): (
KvEventPublishers, KvEventPublishers,
...@@ -499,12 +499,14 @@ impl MockEngine { ...@@ -499,12 +499,14 @@ impl MockEngine {
loop { loop {
tokio::select! { tokio::select! {
signal_result = output_rx.recv() => { signal_result = output_rx.recv() => {
let Some(signal) = signal_result else { let Some(output_batch) = signal_result else {
break; // Channel closed break; // Channel closed
}; };
if let Some(request_tx) = active_requests_clone.get(&signal.uuid) { for signal in output_batch {
let _ = request_tx.send(signal); if let Some(request_tx) = active_requests_clone.get(&signal.uuid) {
let _ = request_tx.send(signal);
}
} }
} }
_ = cancel_token_cloned.cancelled() => { _ = cancel_token_cloned.cancelled() => {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment