Unverified commit 8379b0cd, authored by Yan Ru Pei, committed by GitHub
Browse files

feat: kv router should route to available instances (#4225)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 8a63d9ce
......@@ -950,7 +950,8 @@ pub async fn create_worker_selection_pipeline_chat(
let component = distributed_runtime
.namespace(namespace)?
.component(component_name)?;
let client = component.endpoint(GENERATE_ENDPOINT).client().await?;
let endpoint = component.endpoint(GENERATE_ENDPOINT);
let client = endpoint.client().await?;
// Discover the model card by searching all instances with this model name
tracing::debug!("Looking for model: {}", model_name);
......@@ -980,7 +981,7 @@ pub async fn create_worker_selection_pipeline_chat(
let chooser = if router_mode == RouterMode::KV {
Some(
model_manager
.kv_chooser_for(&component, card.kv_cache_block_size, kv_router_config)
.kv_chooser_for(&endpoint, card.kv_cache_block_size, kv_router_config)
.await?,
)
} else {
......
......@@ -998,13 +998,10 @@ async fn create_kv_router_from_endpoint(
block_size: usize,
kv_router_config: Option<llm_rs::kv_router::KvRouterConfig>,
) -> Result<Arc<llm_rs::kv_router::KvRouter>, PyErr> {
// Get component from endpoint
let component = endpoint.inner.component();
// Create ModelManager and use it to create KvRouter (ensures registration)
let model_manager = Arc::new(llm_rs::discovery::ModelManager::new());
let kv_router = model_manager
.kv_chooser_for(component, block_size as u32, kv_router_config)
.kv_chooser_for(&endpoint.inner, block_size as u32, kv_router_config)
.await
.map_err(to_pyerr)?;
......
......@@ -10,10 +10,7 @@ use parking_lot::{Mutex, RwLock};
use tokio::sync::oneshot;
use dynamo_runtime::prelude::DistributedRuntimeProvider;
use dynamo_runtime::{
component::{Component, Endpoint},
storage::key_value_store::Key,
};
use dynamo_runtime::{component::Endpoint, storage::key_value_store::Key};
use crate::{
discovery::KV_ROUTERS_ROOT_PATH,
......@@ -292,32 +289,33 @@ impl ModelManager {
pub async fn kv_chooser_for(
&self,
component: &Component,
endpoint: &Endpoint,
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
) -> anyhow::Result<Arc<KvRouter>> {
let service_name = component.service_name();
let endpoint_path = endpoint.path();
if let Some(kv_chooser) = self.get_kv_chooser(&service_name) {
if let Some(kv_chooser) = self.get_kv_chooser(&endpoint_path) {
// Check if the existing router has a different block size
if kv_chooser.block_size() != kv_cache_block_size {
tracing::warn!(
component = %service_name,
endpoint = %endpoint_path,
existing_block_size = %kv_chooser.block_size(),
requested_block_size = %kv_cache_block_size,
"KV Router block size mismatch! Component is requesting a different kv_cache_block_size than the existing router. \
"KV Router block size mismatch! Endpoint is requesting a different kv_cache_block_size than the existing router. \
This will cause routing to fail silently. Consider using the same block size or restarting the router."
);
}
return Ok(kv_chooser);
}
let store = component.drt().store();
let client = endpoint.client().await?;
let store = endpoint.component().drt().store();
let router_bucket = store
.get_or_create_bucket(KV_ROUTERS_ROOT_PATH, None)
.await?;
let router_uuid = uuid::Uuid::new_v4();
let router_key = Key::from_raw(format!("{}/{router_uuid}", component.path()));
let router_key = Key::from_raw(format!("{}/{router_uuid}", endpoint.path()));
let json_router_config = serde_json::to_vec_pretty(&kv_router_config.unwrap_or_default())?;
router_bucket
.insert(&router_key, json_router_config.into(), 0)
......@@ -325,7 +323,8 @@ impl ModelManager {
let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
let chooser = KvRouter::new(
component.clone(),
endpoint.clone(),
client,
kv_cache_block_size,
Some(selector),
kv_router_config,
......@@ -335,7 +334,7 @@ impl ModelManager {
let new_kv_chooser = Arc::new(chooser);
self.kv_choosers
.lock()
.insert(service_name, new_kv_chooser.clone());
.insert(endpoint_path, new_kv_chooser.clone());
Ok(new_kv_chooser)
}
......
......@@ -370,10 +370,11 @@ impl ModelWatcher {
// A model that expects pre-processed requests meaning it's up to us whether we
// handle Chat or Completions requests, so handle whatever the model supports.
let endpoint = component.endpoint(&endpoint_id.name);
let kv_chooser = if self.router_mode == RouterMode::KV {
Some(
self.manager
.kv_chooser_for(&component, card.kv_cache_block_size, self.kv_router_config)
.kv_chooser_for(&endpoint, card.kv_cache_block_size, self.kv_router_config)
.await?,
)
} else {
......
......@@ -224,17 +224,27 @@ where
let backend = Backend::from_tokenizer(hf_tokenizer).into_operator();
let migration = Migration::from_mdc(card).into_operator();
// For KV routing, use the client from the chooser to ensure shared state
let router_client = if router_mode == RouterMode::KV {
let Some(ref chooser) = chooser else {
anyhow::bail!("RouterMode::KV requires KVRouter to not be null");
};
chooser.client().clone()
} else {
client.clone()
};
// Create worker monitor only if busy_threshold is set
let worker_monitor = busy_threshold.map(|threshold| {
Arc::new(crate::discovery::KvWorkerMonitor::new(
Arc::new(client.clone()),
Arc::new(router_client.clone()),
threshold,
)) as Arc<dyn dynamo_runtime::pipeline::WorkerLoadMonitor>
});
let router =
PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client.clone(),
router_client,
router_mode,
busy_threshold,
worker_monitor,
......
......@@ -8,7 +8,7 @@ use std::time::Duration;
use anyhow::Result;
use derive_builder::Builder;
use dynamo_runtime::{
component::Component,
component::{Client, Endpoint},
discovery::{DiscoveryQuery, watch_and_extract_field},
pipeline::{
AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
......@@ -213,29 +213,32 @@ pub struct KvRouter {
kv_router_config: KvRouterConfig,
cancellation_token: tokio_util::sync::CancellationToken,
client: Client,
}
impl KvRouter {
pub async fn new(
component: Component,
endpoint: Endpoint,
client: Client,
block_size: u32,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
kv_router_config: Option<KvRouterConfig>,
consumer_uuid: String,
) -> Result<Self> {
let kv_router_config = kv_router_config.unwrap_or_default();
let component = endpoint.component();
let cancellation_token = component.drt().primary_token();
let generate_endpoint = component.endpoint("generate");
let client = generate_endpoint.client().await?;
let instances_rx = client.instance_source.as_ref().clone();
let instance_ids_rx = client.instance_avail_watcher();
// Watch for runtime config updates via discovery interface
let discovery = component.drt().discovery();
let endpoint_id = endpoint.id();
let discovery_key = DiscoveryQuery::EndpointModels {
namespace: component.namespace().name().to_string(),
component: component.name().to_string(),
endpoint: "generate".to_string(),
namespace: endpoint_id.namespace.clone(),
component: endpoint_id.component.clone(),
endpoint: endpoint_id.name.clone(),
};
let discovery_stream = discovery
.list_and_watch(discovery_key, Some(cancellation_token.clone()))
......@@ -249,7 +252,7 @@ impl KvRouter {
// When overlap_score_weight is zero, we don't need to track prefixes
Indexer::None
} else if kv_router_config.use_kv_events {
let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(&component);
let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
Indexer::KvIndexer(KvIndexer::new(
cancellation_token.clone(),
block_size,
......@@ -271,7 +274,7 @@ impl KvRouter {
let scheduler = KvScheduler::start(
component.clone(),
block_size,
instances_rx,
instance_ids_rx,
runtime_configs_rx,
selector,
kv_router_config.router_replica_sync,
......@@ -306,9 +309,15 @@ impl KvRouter {
block_size,
kv_router_config,
cancellation_token,
client,
})
}
/// Get a reference to the client used by this KvRouter.
///
/// This is the `Client` that was passed to `KvRouter::new` and stored on the
/// router. Callers that build a `PushRouter` for KV routing clone this client
/// rather than creating a fresh one, so that routing shares state with the
/// router's own view of the endpoint.
pub fn client(&self) -> &Client {
    &self.client
}
/// Given these tokens, find the worker with the best match in its KV cache.
/// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
/// Now also takes optional context_id for request tracking
......
......@@ -107,15 +107,16 @@ impl PrefillRouter {
"Activating prefill router"
);
let client = endpoint.client().await?;
let inner_router = if self.router_mode.is_kv_routing() {
// Create KV chooser using the component from the endpoint
// Create KV chooser using the endpoint
let kv_chooser = model_manager
.kv_chooser_for(endpoint.component(), kv_cache_block_size, kv_router_config)
.kv_chooser_for(&endpoint, kv_cache_block_size, kv_router_config)
.await?;
// Build the PushRouter for prefill with KV mode
// Extract client from kv_chooser to ensure shared state
let client = kv_chooser.client().clone();
// Build the PushRouter for prefill with KV mode using the shared client
let push_router = PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client,
RouterMode::KV,
......@@ -127,6 +128,9 @@ impl PrefillRouter {
// Wrap it in KvPushRouter
InnerPrefillRouter::KvRouter(Arc::new(KvPushRouter::new(push_router, kv_chooser)))
} else {
// Create client for simple router
let client = endpoint.client().await?;
// Create simple push router with the frontend's router mode
let push_router = PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client,
......
......@@ -3,7 +3,7 @@
use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result;
use dynamo_runtime::component::{Component, Instance};
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventPublisher;
use rand::Rng;
......@@ -96,27 +96,26 @@ impl KvScheduler {
pub async fn start(
component: Component,
block_size: u32,
instances_rx: watch::Receiver<Vec<Instance>>,
instance_ids_rx: watch::Receiver<Vec<u64>>,
runtime_configs_rx: watch::Receiver<HashMap<WorkerId, ModelRuntimeConfig>>,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
replica_sync: bool,
router_uuid: String,
) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
let instances: Vec<Instance> = instances_rx.borrow().clone();
let instance_ids: Vec<u64> = instance_ids_rx.borrow().clone();
let runtime_configs: HashMap<WorkerId, ModelRuntimeConfig> =
runtime_configs_rx.borrow().clone();
// Create shared workers_with_configs wrapped in Arc<RwLock>
let workers_with_configs: Arc<RwLock<HashMap<WorkerId, Option<ModelRuntimeConfig>>>> = {
let mut initial_map = HashMap::new();
for instance in &instances {
let worker_id = instance.instance_id;
let config = runtime_configs.get(&worker_id).cloned();
for worker_id in &instance_ids {
let config = runtime_configs.get(worker_id).cloned();
if config.is_some() {
tracing::info!("Runtime config found for worker_id: {}", worker_id);
}
initial_map.insert(worker_id, config);
initial_map.insert(*worker_id, config);
}
Arc::new(RwLock::new(initial_map))
};
......@@ -132,7 +131,7 @@ impl KvScheduler {
// Spawn background task to monitor and update workers_with_configs
let workers_monitor = workers_with_configs.clone();
let slots_monitor = slots.clone();
let mut instances_monitor_rx = instances_rx.clone();
let mut instance_ids_monitor_rx = instance_ids_rx.clone();
let mut configs_monitor_rx = runtime_configs_rx.clone();
let monitor_cancel_token = component.drt().primary_token();
tokio::spawn(async move {
......@@ -144,9 +143,9 @@ impl KvScheduler {
tracing::trace!("workers monitoring task shutting down");
break;
}
result = instances_monitor_rx.changed() => {
result = instance_ids_monitor_rx.changed() => {
if result.is_err() {
tracing::warn!("endpoint watch sender shutdown in monitor");
tracing::warn!("instance IDs watch sender shutdown in monitor");
break;
}
}
......@@ -159,18 +158,17 @@ impl KvScheduler {
}
// Get the latest values from both channels
let new_instances = instances_monitor_rx.borrow_and_update().clone();
let new_instance_ids = instance_ids_monitor_rx.borrow_and_update().clone();
let new_configs = configs_monitor_rx.borrow_and_update().clone();
// Build the new workers_with_configs map
let mut new_workers_with_configs = HashMap::new();
for instance in &new_instances {
let worker_id = instance.instance_id;
let config = new_configs.get(&worker_id).cloned();
for worker_id in &new_instance_ids {
let config = new_configs.get(worker_id).cloned();
if config.is_some() {
tracing::info!("Runtime config found for worker_id: {}", worker_id);
}
new_workers_with_configs.insert(worker_id, config);
new_workers_with_configs.insert(*worker_id, config);
}
// Update workers when instances change
......
......@@ -31,6 +31,10 @@ pub struct Client {
instance_avail: Arc<ArcSwap<Vec<u64>>>,
// These are the instance source ids less those reported as busy (above threshold)
instance_free: Arc<ArcSwap<Vec<u64>>>,
// Watch sender for available instance IDs (for sending updates)
instance_avail_tx: Arc<tokio::sync::watch::Sender<Vec<u64>>>,
// Watch receiver for available instance IDs (for cloning to external subscribers)
instance_avail_rx: tokio::sync::watch::Receiver<Vec<u64>>,
}
impl Client {
......@@ -46,11 +50,14 @@ impl Client {
endpoint.path()
);
let (avail_tx, avail_rx) = tokio::sync::watch::channel(vec![]);
let client = Client {
endpoint: endpoint.clone(),
instance_source: instance_source.clone(),
instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
instance_free: Arc::new(ArcSwap::from(Arc::new(vec![]))),
instance_avail_tx: Arc::new(avail_tx),
instance_avail_rx: avail_rx,
};
tracing::debug!(
"Client::new_dynamic: Starting instance source monitor for endpoint: {}",
......@@ -90,6 +97,11 @@ impl Client {
self.instance_free.load()
}
/// Get a watcher for available instance IDs.
///
/// Returns a clone of the stored `tokio::sync::watch::Receiver`, so each caller
/// gets its own handle onto the same channel. The channel starts with an empty
/// list; a new list is sent whenever the instance source monitor updates the
/// available instances and whenever an instance is filtered out of the
/// available set (the "inhibiting instance" path).
pub fn instance_avail_watcher(&self) -> tokio::sync::watch::Receiver<Vec<u64>> {
    self.instance_avail_rx.clone()
}
/// Wait for at least one Instance to be available for this Endpoint
pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
tracing::debug!(
......@@ -138,7 +150,10 @@ impl Client {
.iter()
.filter_map(|&id| if id == instance_id { None } else { Some(id) })
.collect::<Vec<_>>();
self.instance_avail.store(Arc::new(filtered));
self.instance_avail.store(Arc::new(filtered.clone()));
// Notify watch channel subscribers about the change
let _ = self.instance_avail_tx.send(filtered);
tracing::debug!("inhibiting instance {instance_id}");
}
......@@ -184,6 +199,9 @@ impl Client {
client.instance_avail.store(Arc::new(instance_ids.clone()));
client.instance_free.store(Arc::new(instance_ids.clone()));
// Send update to watch channel subscribers
let _ = client.instance_avail_tx.send(instance_ids);
tracing::debug!(
"monitor_instance_source: instance source updated, endpoint={}",
endpoint_path
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment