Unverified Commit 9e33e3fa authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: clean ups in kv_router.rs (#6028)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 07db5895
......@@ -2464,6 +2464,7 @@ dependencies = [
"tokio-util",
"tracing",
"uuid 1.18.1",
"validator",
]
[[package]]
......
......@@ -178,24 +178,23 @@ impl Flags {
}
pub fn router_config(&self) -> RouterConfig {
RouterConfig::new(
self.router_mode.into(),
KvRouterConfig::new(
self.kv_overlap_score_weight,
self.router_temperature,
self.use_kv_events,
self.router_replica_sync,
self.router_track_active_blocks,
None, // track_output_blocks
// defaulting below args (no longer maintaining new flags for dynamo-run)
None, // assume_kv_reuse
None,
None,
None,
None,
None,
),
)
let mut cfg = KvRouterConfig::default();
if let Some(w) = self.kv_overlap_score_weight {
cfg.overlap_score_weight = w;
}
if let Some(t) = self.router_temperature {
cfg.router_temperature = t;
}
if let Some(v) = self.use_kv_events {
cfg.use_kv_events = v;
}
if let Some(v) = self.router_replica_sync {
cfg.router_replica_sync = v;
}
if let Some(v) = self.router_track_active_blocks {
cfg.router_track_active_blocks = v;
}
RouterConfig::new(self.router_mode.into(), cfg)
}
/// Load extra engine arguments from a JSON file
......
......@@ -497,20 +497,13 @@ pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline(
};
let kv_router_config = if use_kv_routing {
Some(KvRouterConfig::new(
(overlap_score_weight >= 0.0).then_some(overlap_score_weight),
(router_temperature >= 0.0).then_some(router_temperature),
Some(use_kv_events),
Some(router_replica_sync),
None, // track_active_blocks
None, // track_output_blocks
None, // assume_kv_reuse
None, // router_snapshot_threshold
None, // router_reset_states
None, // router_ttl_secs
None, // router_max_tree_size
None, // router_prune_target_ratio
))
Some(KvRouterConfig {
overlap_score_weight,
router_temperature,
use_kv_events,
router_replica_sync,
..KvRouterConfig::default()
})
} else {
None
};
......
......@@ -24,6 +24,7 @@ use futures::stream::{self, StreamExt};
use rand::Rng;
use serde::{Deserialize, Serialize};
use serde_json::json;
use validator::Validate;
// Re-export from dynamo-kv-router crate
pub use dynamo_kv_router::approx;
......@@ -123,20 +124,23 @@ pub trait WorkerSelector {
}
/// Override configuration for router settings that can be specified per-request
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize, Validate)]
pub struct RouterConfigOverride {
#[builder(default)]
pub overlap_score_weight: Option<f64>,
#[builder(default)]
#[validate(range(min = 0.0))]
pub router_temperature: Option<f64>,
}
/// KV Router configuration parameters
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Validate)]
pub struct KvRouterConfig {
#[validate(range(min = 0.0))]
pub overlap_score_weight: f64,
#[validate(range(min = 0.0))]
pub router_temperature: f64,
pub use_kv_events: bool,
......@@ -157,18 +161,22 @@ pub struct KvRouterConfig {
pub router_assume_kv_reuse: bool,
/// Threshold for triggering snapshots. If None, no snapshots will be performed.
#[validate(range(min = 1))]
pub router_snapshot_threshold: Option<u32>,
/// Whether to reset the router state on startup (default: false)
pub router_reset_states: bool,
/// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
#[validate(range(min = 0.0))]
pub router_ttl_secs: f64,
/// Maximum tree size before pruning (only used when use_kv_events is false, default: 2^20 = 1048576)
#[validate(range(min = 1))]
pub router_max_tree_size: usize,
/// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
#[validate(range(min = 0.0, max = 1.0))]
pub router_prune_target_ratio: f64,
}
......@@ -192,44 +200,6 @@ impl Default for KvRouterConfig {
}
impl KvRouterConfig {
/// Create a new KvRouterConfig with optional weight values.
/// If a weight is None, the default value will be used.
#[allow(clippy::too_many_arguments)]
pub fn new(
overlap_score_weight: Option<f64>,
temperature: Option<f64>,
use_kv_events: Option<bool>,
replica_sync: Option<bool>,
track_active_blocks: Option<bool>,
track_output_blocks: Option<bool>,
assume_kv_reuse: Option<bool>,
router_snapshot_threshold: Option<Option<u32>>,
router_reset_states: Option<bool>,
router_ttl_secs: Option<f64>,
router_max_tree_size: Option<usize>,
router_prune_target_ratio: Option<f64>,
) -> Self {
let default = Self::default();
Self {
overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
router_temperature: temperature.unwrap_or(default.router_temperature),
use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
router_track_active_blocks: track_active_blocks
.unwrap_or(default.router_track_active_blocks),
router_track_output_blocks: track_output_blocks
.unwrap_or(default.router_track_output_blocks),
router_assume_kv_reuse: assume_kv_reuse.unwrap_or(default.router_assume_kv_reuse),
router_snapshot_threshold: router_snapshot_threshold
.unwrap_or(default.router_snapshot_threshold),
router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
router_ttl_secs: router_ttl_secs.unwrap_or(default.router_ttl_secs),
router_max_tree_size: router_max_tree_size.unwrap_or(default.router_max_tree_size),
router_prune_target_ratio: router_prune_target_ratio
.unwrap_or(default.router_prune_target_ratio),
}
}
/// Compute sequence hashes for active block tracking based on configuration.
///
/// Returns:
......@@ -347,6 +317,7 @@ impl KvRouter {
worker_type: &'static str,
) -> Result<Self> {
let kv_router_config = kv_router_config.unwrap_or_default();
kv_router_config.validate()?;
let component = endpoint.component();
let cancellation_token = component.drt().primary_token();
......@@ -509,9 +480,8 @@ impl KvRouter {
#[cfg(feature = "bench")]
let start = Instant::now();
// Validate that context_id is provided when update_states is true
if update_states && context_id.is_none() {
panic!("context_id must be provided if update_states is true");
anyhow::bail!("context_id must be provided when update_states is true");
}
let isl_tokens = tokens.len();
......@@ -784,12 +754,12 @@ impl KvPushRouter {
handle_local_updates: bool,
) -> Result<WorkerSelection, Error> {
let routing = request.routing.as_ref();
// Extract LORA name from routing hints
let lora_name = routing.and_then(|r| r.lora_name.clone());
let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0);
let expected_output_tokens = routing.and_then(|r| r.expected_output_tokens);
// Get pre-selected worker based on phase, with backend_instance_id as fallback
let Some(id) = (match phase {
let preselected_id = match phase {
RequestPhase::Prefill => {
routing.and_then(|r| r.prefill_worker_id.or(r.backend_instance_id))
}
......@@ -797,9 +767,9 @@ impl KvPushRouter {
routing.and_then(|r| r.decode_worker_id.or(r.backend_instance_id))
}
RequestPhase::Aggregated => routing.and_then(|r| r.backend_instance_id),
}) else {
// No preselected worker - find the best match
// Don't update states if this is a query-only request
};
let Some(id) = preselected_id else {
let (best_worker, overlap_amount) = self
.chooser
.find_best_match(
......@@ -818,8 +788,6 @@ impl KvPushRouter {
});
};
// Route to pre-selected or explicitly specified worker
let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0);
tracing::debug!(
worker_id = id,
dp_rank = dp_rank,
......@@ -827,20 +795,12 @@ impl KvPushRouter {
"Routing to specified worker"
);
// Compute actual overlap blocks by querying the indexer
let worker = WorkerWithDpRank::new(id, dp_rank);
let overlap_blocks = self
.chooser
.get_overlap_blocks(&request.token_ids, worker)
.await?;
// Extract expected_output_tokens from routing hints
let expected_output_tokens = request
.routing
.as_ref()
.and_then(|r| r.expected_output_tokens);
// Perform add_request if this router handles local updates
if !is_query_only && handle_local_updates {
self.chooser
.add_request(
......
......@@ -73,6 +73,7 @@ pub struct MediaLoader {
media_decoder: MediaDecoder,
#[allow(dead_code)]
http_client: reqwest::Client,
#[allow(dead_code)]
media_fetcher: MediaFetcher,
#[cfg(feature = "media-nixl")]
nixl_agent: NixlAgent,
......
......@@ -17,6 +17,7 @@ dynamo-kv-router = { workspace = true }
# workspace
anyhow = { workspace = true }
validator = { workspace = true }
dashmap = { workspace = true }
derive_builder = { workspace = true }
derive-getters = { workspace = true }
......
......@@ -7,6 +7,7 @@ use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use uuid::Uuid;
use validator::Validate;
use crate::perf_model::PerfModel;
use dynamo_kv_router::protocols::KvCacheEvent;
......@@ -83,21 +84,25 @@ pub enum WorkerType {
}
/// Configuration arguments for MockVllmEngine
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
#[derive(Debug, Clone, Serialize, Deserialize, Builder, Validate)]
#[builder(pattern = "owned", build_fn(public))]
pub struct MockEngineArgs {
#[builder(default = "16384")]
#[validate(range(min = 1))]
pub num_gpu_blocks: usize,
#[builder(default = "64")]
#[validate(range(min = 2))]
pub block_size: usize,
// This was 1024 in the past but reverted back to 256
#[builder(default = Some(256))]
#[validate(range(min = 1))]
pub max_num_seqs: Option<usize>,
// default for open api server, for llm class it's 16384
#[builder(default = Some(8192))]
#[validate(range(min = 1))]
pub max_num_batched_tokens: Option<usize>,
#[builder(default = true)]
......@@ -107,16 +112,20 @@ pub struct MockEngineArgs {
pub enable_chunked_prefill: bool,
#[builder(default = "0.01")]
#[validate(range(min = 0.0, max = 1.0))]
pub watermark: f64,
#[builder(default = "1.0")]
#[validate(range(min = 0.0))]
pub speedup_ratio: f64,
#[builder(default = "1")]
#[validate(range(min = 1))]
pub dp_size: u32,
/// Optional startup time in seconds to simulate engine initialization delay
#[builder(default = "None")]
#[validate(range(min = 0.0))]
pub startup_time: Option<f64>,
/// Worker type for disaggregated serving (Aggregated, Prefill, or Decode)
......
......@@ -45,6 +45,7 @@ use tokio::sync::mpsc;
use tokio::time::Duration;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use validator::Validate;
/// Simple metrics struct for mocker's internal use
#[derive(Clone, Default, Debug)]
......@@ -259,12 +260,7 @@ impl Scheduler {
kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
cancellation_token: Option<CancellationToken>,
) -> Self {
// Assert speedup_ratio is non-negative (0 means infinite speedup)
assert!(
args.speedup_ratio >= 0.0,
"speedup_ratio must be >= 0 (0 means infinite speedup), got: {}",
args.speedup_ratio
);
args.validate().expect("invalid MockEngineArgs");
// Create channel for request handling
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
......
......@@ -6,6 +6,7 @@ use derive_getters::Getters;
use dynamo_tokens::blocks::UniqueBlock;
use dynamo_tokens::{TokenBlockSequence, Tokens};
use rand::random;
use validator::Validate;
/// Create unique blocks from a TokenBlockSequence
fn create_unique_blocks_from_sequence(
......@@ -34,13 +35,14 @@ fn create_unique_blocks_from_sequence(
/// A sequence that is actively being built, with the ability to add tokens and commit to hashes
/// TODO: reuse tokens
#[derive(Debug, Getters)]
#[derive(Debug, Getters, Validate)]
pub struct ActiveSequence {
unique_blocks: Vec<UniqueBlock>,
tokens: TokenBlockSequence,
#[getter(copy)]
#[validate(range(min = 2))]
block_size: usize,
#[getter(copy)]
......@@ -67,7 +69,6 @@ impl ActiveSequence {
enable_prefix_caching: bool,
) -> Self {
let block_size = block_size.unwrap_or(64);
assert!(block_size > 1, "block_size must be greater than 1");
let num_input_tokens = tokens.len();
let tokens = Tokens::from(tokens).into_sequence(block_size as u32, Some(1337));
......@@ -76,7 +77,7 @@ impl ActiveSequence {
let block_hashes = tokens.blocks().iter().map(|b| b.block_hash()).collect();
let creation_signal = Some(MoveBlock::Use(unique_blocks.clone(), block_hashes));
Self {
let seq = Self {
unique_blocks,
tokens,
block_size,
......@@ -85,7 +86,9 @@ impl ActiveSequence {
num_input_tokens,
creation_signal,
enable_prefix_caching,
}
};
seq.validate().expect("invalid ActiveSequence");
seq
}
pub fn extra_tokens(&self) -> u32 {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment