Commit c4106e6a authored by Ryan Olson, committed by GitHub
Browse files

feat: kv aware router executable (#399)

parent 183941fa
...@@ -5166,6 +5166,20 @@ dependencies = [ ...@@ -5166,6 +5166,20 @@ dependencies = [
"windows-sys 0.52.0", "windows-sys 0.52.0",
] ]
[[package]]
name = "router"
version = "0.1.0"
dependencies = [
"clap",
"dynamo-llm",
"dynamo-runtime",
"rand 0.9.0",
"serde",
"serde_json",
"tokio",
"tracing",
]
[[package]] [[package]]
name = "rstest" name = "rstest"
version = "0.18.2" version = "0.18.2"
......
...@@ -465,8 +465,8 @@ impl PrometheusMetrics { ...@@ -465,8 +465,8 @@ impl PrometheusMetrics {
/// Update metrics with current values /// Update metrics with current values
fn update(&self, config: &LLMWorkerLoadCapacityConfig, processed: &ProcessedEndpoints) { fn update(&self, config: &LLMWorkerLoadCapacityConfig, processed: &ProcessedEndpoints) {
// Update per-worker metrics // Update per-worker metrics
for endpoint in processed.endpoints.iter() { for (worker_id, endpoint) in processed.endpoints.iter() {
let worker_id = endpoint.worker_id().to_string(); let worker_id = worker_id.to_string();
let metrics = endpoint.data.clone(); let metrics = endpoint.data.clone();
self.set_worker_gauge( self.set_worker_gauge(
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Manifest for the standalone KV-aware router executable.
# All package metadata is inherited from the workspace root.
[package]
name = "router"
version.workspace = true
edition.workspace = true
description.workspace = true
authors.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
keywords.workspace = true

# Shared dependencies are version-pinned at the workspace level.
# clap is declared locally because only this binary parses a CLI.
[dependencies]
dynamo-runtime = { workspace = true}
dynamo-llm = { workspace = true}
rand = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
clap = { version = "4.5", features = ["derive"] }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// TODO(#400):
// Instead of passing in a block_size, we should get this data from the backend component's config.
// What changes need to be made:
// 1. Take as an argument the name of the backend component.
// 2. Update the backend component to produce a config in a standard location.
// 3. Update the KvRouter to read the config from the backend component.
use clap::Parser;
use dynamo_llm::kv_router::{
protocols::WorkerSelectionResult,
scheduler::{DefaultWorkerSelector, KvSchedulerError, SchedulingRequest},
scoring::ProcessedEndpoints,
KvRouter, WorkerSelector,
};
use dynamo_runtime::{
logging, pipeline::network::Ingress, DistributedRuntime, Result, Runtime, Worker,
};
/// Command-line arguments for the KV-aware router executable.
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Namespace for the distributed component
    #[arg(long)]
    namespace: String,

    /// Component name for the service
    #[arg(long, default_value = "kv_aware_router")]
    component: String,

    /// Block size for the router
    ///
    /// NOTE(review): per TODO(#400), this should eventually be read from the
    /// backend component's config instead of being passed as a CLI flag.
    #[arg(long)]
    block_size: usize,
}
/// Entry point: initialize logging, build a worker from environment
/// settings, and hand control to the async `app` function.
fn main() -> Result<()> {
    logging::init();
    Worker::from_settings()?.execute(app)
}
/// Builds the distributed runtime, constructs the KV-aware router for the
/// requested namespace/component, and serves it on the "generate" endpoint.
async fn app(runtime: Runtime) -> Result<()> {
    let args = Args::parse();
    let drt = DistributedRuntime::from_settings(runtime).await?;

    let namespace = drt.namespace(&args.namespace)?;
    let component = namespace.component(&args.component)?;

    // Plug in the (currently delegating) custom selection policy.
    let selector = Box::new(CustomWorkerSelector::default());
    let engine = KvRouter::new(component.clone(), args.block_size, Some(selector)).await?;
    let ingress = Ingress::for_engine(engine)?;

    let service = component.service_builder().create().await?;
    service
        .endpoint("generate")
        .endpoint_builder()
        .handler(ingress)
        .start()
        .await
}
/// A user-customizable worker selector. It currently wraps and delegates to
/// `DefaultWorkerSelector`; edit `select_worker` to install custom routing
/// logic.
#[derive(Default)]
pub struct CustomWorkerSelector(DefaultWorkerSelector);
impl WorkerSelector for CustomWorkerSelector {
    /// Picks a worker for `request` from the current set of `workers`.
    ///
    /// Currently forwards unchanged to the wrapped `DefaultWorkerSelector`;
    /// replace the delegation below to change the scheduling policy.
    fn select_worker(
        &self,
        workers: &ProcessedEndpoints,
        request: &SchedulingRequest,
        block_size: usize,
    ) -> Result<WorkerSelectionResult, KvSchedulerError> {
        // customize logic here
        // F12 into [DefaultWorkerSelector] to see the original logic
        self.0.select_worker(workers, request, block_size)
    }
}
...@@ -6,10 +6,8 @@ ...@@ -6,10 +6,8 @@
], ],
"settings": { "settings": {
"rust-analyzer.linkedProjects": [ "rust-analyzer.linkedProjects": [
"components/metrics/Cargo.toml", "Cargo.toml",
"launch/dynamo-run/Cargo.toml", "launch/dynamo-run/Cargo.toml",
"lib/llm/Cargo.toml",
"lib/runtime/Cargo.toml",
"lib/bindings/python/Cargo.toml" "lib/bindings/python/Cargo.toml"
], ],
"rust-analyzer.procMacro.enable": true, "rust-analyzer.procMacro.enable": true,
......
...@@ -30,17 +30,13 @@ pub(crate) struct KvRouter { ...@@ -30,17 +30,13 @@ pub(crate) struct KvRouter {
#[pymethods] #[pymethods]
impl KvRouter { impl KvRouter {
#[new] #[new]
// [FXIME] 'drt' can be obtained from 'component' fn new(component: Component, kv_block_size: usize) -> PyResult<Self> {
fn new(drt: DistributedRuntime, component: Component, kv_block_size: usize) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime(); let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async { runtime.block_on(async {
let inner = llm_rs::kv_router::KvRouter::from_runtime( let inner =
drt.inner.clone(), llm_rs::kv_router::KvRouter::new(component.inner.clone(), kv_block_size, None)
component.inner.clone(), .await
kv_block_size, .map_err(to_pyerr)?;
)
.await
.map_err(to_pyerr)?;
Ok(Self { inner }) Ok(Self { inner })
}) })
} }
...@@ -376,8 +372,8 @@ impl KvMetricsAggregator { ...@@ -376,8 +372,8 @@ impl KvMetricsAggregator {
let endpoint_kv_metrics = endpoints let endpoint_kv_metrics = endpoints
.endpoints .endpoints
.iter() .iter()
.map(|x| EndpointKvMetrics { .map(|(worker_id, x)| EndpointKvMetrics {
worker_id: x.worker_id(), worker_id: *worker_id,
request_active_slots: x.data.request_active_slots, request_active_slots: x.data.request_active_slots,
request_total_slots: x.data.request_total_slots, request_total_slots: x.data.request_total_slots,
kv_active_blocks: x.data.kv_active_blocks, kv_active_blocks: x.data.kv_active_blocks,
......
...@@ -14,11 +14,17 @@ ...@@ -14,11 +14,17 @@
// limitations under the License. // limitations under the License.
use anyhow::Result; use anyhow::Result;
use dynamo_runtime::{component::Component, component::Namespace, DistributedRuntime}; use dynamo_runtime::{
use futures::stream::StreamExt; component::Component,
pipeline::{
async_trait, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream,
SingleIn,
},
prelude::*,
protocols::annotated::Annotated,
};
use futures::stream::{self, StreamExt};
use std::sync::Arc; use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use tracing;
pub mod indexer; pub mod indexer;
pub mod metrics_aggregator; pub mod metrics_aggregator;
...@@ -27,14 +33,18 @@ pub mod publisher; ...@@ -27,14 +33,18 @@ pub mod publisher;
pub mod scheduler; pub mod scheduler;
pub mod scoring; pub mod scoring;
use crate::kv_router::{ use crate::{
indexer::{KvIndexer, KvIndexerInterface, RouterEvent}, kv_router::{
metrics_aggregator::collect_endpoints_task, indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
scheduler::KvScheduler, metrics_aggregator::KvMetricsAggregator,
scoring::ProcessedEndpoints, protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
scoring::ProcessedEndpoints,
},
tokens::Tokens,
}; };
use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber}; use dynamo_runtime::traits::events::EventSubscriber;
// [gluo TODO] shouldn't need to be public // [gluo TODO] shouldn't need to be public
// this should be discovered from the component // this should be discovered from the component
...@@ -42,49 +52,40 @@ pub const KV_EVENT_SUBJECT: &str = "kv_events"; ...@@ -42,49 +52,40 @@ pub const KV_EVENT_SUBJECT: &str = "kv_events";
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate"; pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
pub const KV_METRICS_ENDPOINT: &str = "load_metrics"; pub const KV_METRICS_ENDPOINT: &str = "load_metrics";
pub struct KvRouter { /// A trait that users can implement to define custom selection logic
// properties of request plane pub trait WorkerSelector {
// maybe rolled up into the generic object or not fn select_worker(
service_name: String, &self,
workers: &ProcessedEndpoints,
cancellation_token: CancellationToken, request: &SchedulingRequest,
block_size: usize,
#[allow(dead_code)] ) -> Result<WorkerSelectionResult, KvSchedulerError>;
scheduler: KvScheduler, }
pub struct KvRouter {
indexer: KvIndexer, indexer: KvIndexer,
scheduler: KvScheduler,
block_size: usize,
} }
impl KvRouter { impl KvRouter {
pub async fn from_runtime(
runtime: DistributedRuntime,
component: Component,
kv_block_size: usize,
) -> Result<Arc<Self>> {
let namespace = runtime.namespace(component.namespace().name())?;
tracing::info!("Component Namespace {}", component.namespace());
tracing::info!("Component Service Name {}", component.service_name());
tracing::info!("KV Subject {}.{}", component.subject(), KV_EVENT_SUBJECT);
Self::new(component, namespace, kv_block_size).await
}
pub async fn new( pub async fn new(
component: Component, component: Component,
namespace: Namespace, block_size: usize,
kv_block_size: usize, selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
) -> Result<Arc<Self>> { ) -> Result<Arc<Self>> {
let cancellation_token = CancellationToken::new(); let cancellation_token = component.drt().primary_lease().primary_token();
let (ep_tx, ep_rx) = tokio::sync::mpsc::channel(128);
let metrics_aggregator =
tokio::spawn(collect_endpoints_task( KvMetricsAggregator::new(component.clone(), cancellation_token.clone()).await;
component.clone(), let indexer = KvIndexer::new(cancellation_token.clone(), block_size);
ep_tx, let scheduler = KvScheduler::start(
cancellation_token.clone(), component.namespace().clone(),
)); block_size,
metrics_aggregator.endpoints_watcher(),
let indexer = KvIndexer::new(cancellation_token.clone(), kv_block_size); selector,
let scheduler = KvScheduler::start(ep_rx, namespace, kv_block_size).await?; )
.await?;
// [gluo TODO] try subscribe_with_type::<RouterEvent>, // [gluo TODO] try subscribe_with_type::<RouterEvent>,
// error checking below will be different. // error checking below will be different.
...@@ -112,21 +113,12 @@ impl KvRouter { ...@@ -112,21 +113,12 @@ impl KvRouter {
}); });
Ok(Arc::new(Self { Ok(Arc::new(Self {
service_name: component.service_name(),
cancellation_token,
scheduler, scheduler,
indexer, indexer,
block_size,
})) }))
} }
pub fn cancellation_token(&self) -> CancellationToken {
self.cancellation_token.clone()
}
pub fn service_name(&self) -> &str {
&self.service_name
}
// [TODO] indexer needs to take 'lora_id' as parameter // [TODO] indexer needs to take 'lora_id' as parameter
pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<i64> { pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<i64> {
// Extracting part of the code in KvRouter::generate() for only // Extracting part of the code in KvRouter::generate() for only
...@@ -141,3 +133,32 @@ impl KvRouter { ...@@ -141,3 +133,32 @@ impl KvRouter {
Ok(worker_id) Ok(worker_id)
} }
} }
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    /// Routes a tokenized request: hashes its token blocks, queries the
    /// indexer for cache overlap, asks the scheduler for a worker, and
    /// yields a single response carrying the chosen worker id.
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (req, ctx) = request.into_parts();
        let isl_tokens = req.tokens.len();
        let block_size = self.block_size;

        // Block hashing is CPU-bound, so run it off the async executor.
        let hashes: Vec<LocalBlockHash> = tokio::task::spawn_blocking(move || {
            Tokens::compute_block_hash(&req.tokens, block_size)
                .into_iter()
                .map(LocalBlockHash)
                .collect()
        })
        .await?;

        let overlap = self.indexer.find_matches(hashes).await?;
        let worker_id = self.scheduler.schedule(overlap, isl_tokens).await?;

        let single = Annotated::from_data(RouterResponse { worker_id });
        let out = stream::iter(vec![single]);
        Ok(ResponseStream::new(Box::pin(out), ctx.context()))
    }
}
...@@ -13,8 +13,6 @@ ...@@ -13,8 +13,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::sync::{Arc, Mutex};
pub use crate::kv_router::protocols::ForwardPassMetrics; pub use crate::kv_router::protocols::ForwardPassMetrics;
use crate::kv_router::KV_METRICS_ENDPOINT; use crate::kv_router::KV_METRICS_ENDPOINT;
...@@ -22,61 +20,36 @@ use crate::kv_router::scheduler::Endpoint; ...@@ -22,61 +20,36 @@ use crate::kv_router::scheduler::Endpoint;
use crate::kv_router::ProcessedEndpoints; use crate::kv_router::ProcessedEndpoints;
use dynamo_runtime::component::Component; use dynamo_runtime::component::Component;
use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result}; use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result};
use tokio::sync::watch;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
pub struct KvMetricsAggregator { pub struct KvMetricsAggregator {
pub service_name: String, pub service_name: String,
pub endpoints: Arc<Mutex<ProcessedEndpoints>>, pub endpoints_rx: watch::Receiver<ProcessedEndpoints>,
} }
impl KvMetricsAggregator { impl KvMetricsAggregator {
pub async fn new(component: Component, cancellation_token: CancellationToken) -> Self { pub async fn new(component: Component, cancellation_token: CancellationToken) -> Self {
let (ep_tx, mut ep_rx) = tokio::sync::mpsc::channel(128); let (watch_tx, watch_rx) = watch::channel(ProcessedEndpoints::default());
tokio::spawn(collect_endpoints_task( tokio::spawn(collect_endpoints_task(
component.clone(), component.clone(),
ep_tx, watch_tx,
cancellation_token.clone(), cancellation_token.clone(),
)); ));
tracing::trace!("awaiting the start of the background endpoint subscriber");
let endpoints = Arc::new(Mutex::new(ProcessedEndpoints::default()));
let endpoints_clone = endpoints.clone();
tokio::spawn(async move {
tracing::debug!("scheduler background task started");
loop {
match ep_rx.recv().await {
Some(endpoints) => match endpoints_clone.lock() {
Ok(mut shared_endpoint) => {
*shared_endpoint = endpoints;
}
Err(e) => {
tracing::error!("Failed to acquire lock on endpoints: {:?}", e);
}
},
None => {
tracing::warn!("endpoint subscriber shutdown");
break;
}
};
}
tracing::trace!("background endpoint subscriber shutting down");
});
Self { Self {
service_name: component.service_name(), service_name: component.service_name(),
endpoints, endpoints_rx: watch_rx,
} }
} }
pub fn get_endpoints(&self) -> ProcessedEndpoints { pub fn get_endpoints(&self) -> ProcessedEndpoints {
match self.endpoints.lock() { self.endpoints_rx.borrow().clone()
Ok(endpoints) => endpoints.clone(), }
Err(e) => {
tracing::error!("Failed to acquire lock on endpoints: {:?}", e); pub fn endpoints_watcher(&self) -> watch::Receiver<ProcessedEndpoints> {
ProcessedEndpoints::default() self.endpoints_rx.clone()
}
}
} }
} }
...@@ -108,7 +81,7 @@ pub async fn collect_endpoints( ...@@ -108,7 +81,7 @@ pub async fn collect_endpoints(
pub async fn collect_endpoints_task( pub async fn collect_endpoints_task(
component: Component, component: Component,
ep_tx: tokio::sync::mpsc::Sender<ProcessedEndpoints>, watch_tx: watch::Sender<ProcessedEndpoints>,
cancel: CancellationToken, cancel: CancellationToken,
) { ) {
let backoff_delay = Duration::from_millis(100); let backoff_delay = Duration::from_millis(100);
...@@ -161,7 +134,8 @@ pub async fn collect_endpoints_task( ...@@ -161,7 +134,8 @@ pub async fn collect_endpoints_task(
); );
let processed = ProcessedEndpoints::new(endpoints); let processed = ProcessedEndpoints::new(endpoints);
if ep_tx.send(processed).await.is_err() {
if watch_tx.send(processed).is_err() {
tracing::trace!("failed to send processed endpoints; shutting down"); tracing::trace!("failed to send processed endpoints; shutting down");
break; break;
} }
......
...@@ -13,8 +13,32 @@ ...@@ -13,8 +13,32 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use crate::tokens::Token;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// Request payload accepted by the router's `generate` endpoint: the token
/// ids of the incoming prompt.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct RouterRequest {
    pub tokens: Vec<Token>,
}
/// Response emitted by the router: the id of the worker selected to serve
/// the request.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct RouterResponse {
    pub worker_id: i64,
}
/// Outcome of one worker-selection decision made by a `WorkerSelector`
/// implementation.
#[derive(Debug)]
pub struct WorkerSelectionResult {
    /// The worker id of the selected worker
    pub worker_id: i64,

    /// The total number of blocks required to prefill the request
    pub required_blocks: u64,

    /// The number of blocks that the selected worker may already have cached.
    /// This is not a guarantee, but an estimate.
    pub overlap_blocks: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ForwardPassMetrics { pub struct ForwardPassMetrics {
pub request_active_slots: u64, pub request_active_slots: u64,
......
...@@ -15,15 +15,19 @@ ...@@ -15,15 +15,19 @@
use dynamo_runtime::component::Namespace; use dynamo_runtime::component::Namespace;
use dynamo_runtime::traits::events::EventPublisher; use dynamo_runtime::traits::events::EventPublisher;
use rand::Rng;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut; use std::borrow::BorrowMut;
use std::cmp::min; use std::collections::HashMap;
use crate::kv_router::indexer::OverlapScores; use crate::kv_router::indexer::OverlapScores;
pub use crate::kv_router::protocols::ForwardPassMetrics; pub use crate::kv_router::protocols::ForwardPassMetrics;
use crate::kv_router::scoring::ProcessedEndpoints; use crate::kv_router::scoring::ProcessedEndpoints;
use crate::kv_router::KV_HIT_RATE_SUBJECT; use crate::kv_router::KV_HIT_RATE_SUBJECT;
use super::protocols::WorkerSelectionResult;
use super::WorkerSelector;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVHitRateEvent { pub struct KVHitRateEvent {
pub worker_id: i64, pub worker_id: i64,
...@@ -68,8 +72,8 @@ impl Endpoint { ...@@ -68,8 +72,8 @@ impl Endpoint {
} }
pub struct SchedulingRequest { pub struct SchedulingRequest {
isl_tokens: usize, pub isl_tokens: usize,
overlap: OverlapScores, pub overlap: OverlapScores,
resp_tx: tokio::sync::oneshot::Sender<i64>, resp_tx: tokio::sync::oneshot::Sender<i64>,
} }
...@@ -87,24 +91,16 @@ pub struct KvScheduler { ...@@ -87,24 +91,16 @@ pub struct KvScheduler {
impl KvScheduler { impl KvScheduler {
pub async fn start( pub async fn start(
endpoints_rx: tokio::sync::mpsc::Receiver<ProcessedEndpoints>,
ns: Namespace, ns: Namespace,
kv_block_size: usize, block_size: usize,
endpoints_rx: tokio::sync::watch::Receiver<ProcessedEndpoints>,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
) -> Result<Self, KvSchedulerError> { ) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector));
let mut endpoints_rx = endpoints_rx; let mut endpoints_rx = endpoints_rx;
let mut endpoints: ProcessedEndpoints = endpoints_rx.borrow_and_update().clone();
tracing::trace!("awaiting the start of the background endpoint subscriber");
let mut endpoints = match endpoints_rx.recv().await {
Some(endpoints) => endpoints,
None => {
return Err(KvSchedulerError::SubscriberShutdown);
}
};
// Channel to asynchronously publish metric events on
let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<KVHitRateEvent>(); let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<KVHitRateEvent>();
// Publisher task
tokio::spawn(async move { tokio::spawn(async move {
let mut event_rx = event_rx; let mut event_rx = event_rx;
while let Some(event) = event_rx.recv().await { while let Some(event) = event_rx.recv().await {
...@@ -115,7 +111,7 @@ impl KvScheduler { ...@@ -115,7 +111,7 @@ impl KvScheduler {
}); });
// Channel to accept new scheduling requests // Channel to accept new scheduling requests
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(16); let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
tracing::debug!("scheduler starting"); tracing::debug!("scheduler starting");
// Background task to handle scheduling requests // Background task to handle scheduling requests
tokio::spawn(async move { tokio::spawn(async move {
...@@ -140,37 +136,33 @@ impl KvScheduler { ...@@ -140,37 +136,33 @@ impl KvScheduler {
} }
} }
new_endpoints = endpoints_rx.recv() => { _ = endpoints_rx.changed() => {
match new_endpoints { endpoints = endpoints_rx.borrow_and_update().clone();
Some(new_endpoints) => { continue 'outer;
tracing::trace!("updated endpoints");
endpoints = new_endpoints;
continue 'outer;
}
None => {
tracing::trace!("endpoint subscriber shutdown");
break 'outer;
}
}
} }
}; };
tracing::debug!("selected"); tracing::debug!("selected");
loop { loop {
match select_worker(endpoints.borrow_mut(), &request, &event_tx, kv_block_size) match selector.select_worker(&endpoints, &request, block_size) {
{ Ok(selection) => {
Ok(worker_id) => { let worker_id = process_worker_selection(
endpoints.borrow_mut(),
selection,
&event_tx,
);
request.respond(worker_id); request.respond(worker_id);
continue 'outer; continue 'outer;
} }
Err(KvSchedulerError::AllWorkersBusy) => { Err(KvSchedulerError::AllWorkersBusy) => {
tracing::trace!("all workers busy; waiting for more capacity"); tracing::trace!("all workers busy; waiting for more capacity");
endpoints = match endpoints_rx.recv().await { match endpoints_rx.changed().await {
Some(endpoints) => endpoints, Ok(_) => {}
None => { Err(e) => {
tracing::trace!("endpoint subscriber shutdown"); tracing::error!("error waiting for endpoints change: {:?}", e);
break 'outer; break 'outer;
} }
}; };
endpoints = endpoints_rx.borrow_and_update().clone();
} }
Err(e) => { Err(e) => {
tracing::error!("error scheduling request: {:?}", e); tracing::error!("error scheduling request: {:?}", e);
...@@ -212,104 +204,137 @@ impl KvScheduler { ...@@ -212,104 +204,137 @@ impl KvScheduler {
} }
} }
pub fn select_worker( // This becomes the driver function that handles the selection result
pub fn process_worker_selection(
workers: &mut ProcessedEndpoints, workers: &mut ProcessedEndpoints,
request: &SchedulingRequest, selection: WorkerSelectionResult,
event_tx: &tokio::sync::mpsc::UnboundedSender<KVHitRateEvent>, event_tx: &tokio::sync::mpsc::UnboundedSender<KVHitRateEvent>,
kv_block_size: usize, ) -> i64 {
) -> Result<i64, KvSchedulerError> { let worker = workers
// balance mode prioritizes balancing load across workers .endpoints
let balance_threshold: f64 = 0.1; .get_mut(&selection.worker_id)
let balance_mode = workers.load_std > balance_threshold * workers.load_avg; .expect("worker not found");
// Determine alpha based on mode // Update worker state
let alpha = if balance_mode { 0.7 } else { 0.3 }; worker.data.request_active_slots += 1;
let gamma = 0.1; // example tuning param worker.data.kv_active_blocks += selection.required_blocks - selection.overlap_blocks as u64;
// Compute each worker's score // Emit event
let mut best_index = None; if let Err(e) = event_tx.send(KVHitRateEvent {
let mut best_cost = f64::INFINITY; worker_id: selection.worker_id,
// [FIXME] REMOVE ONLY FOR TESTING isl_blocks: selection.required_blocks as usize,
if workers.endpoints.is_empty() { overlap_blocks: selection.overlap_blocks,
return Err(KvSchedulerError::NoEndpoints); }) {
tracing::warn!("Failed to send KV hit rate event: {:?}", e);
} }
for (i, w) in workers.endpoints.iter().enumerate() { selection.worker_id
// Exclude workers that are at capacity }
if w.data.request_active_slots >= w.data.request_total_slots
|| w.data.kv_active_blocks >= w.data.kv_total_blocks // Default implementation matching the Python _cost_function
{ #[derive(Default)]
continue; pub struct DefaultWorkerSelector;
impl WorkerSelector for DefaultWorkerSelector {
fn select_worker(
&self,
workers: &ProcessedEndpoints,
request: &SchedulingRequest,
block_size: usize,
) -> Result<WorkerSelectionResult, KvSchedulerError> {
assert!(request.isl_tokens > 0);
let mut worker_scores = HashMap::new();
let mut max_active = 0.0;
// Calculate worker scores and find max waiting requests
for (worker_id, ep) in workers.endpoints.iter() {
// Calculate score similar to Python version
if let Some(score) = request.overlap.scores.get(worker_id) {
let score = *score as f64 * block_size as f64 / request.isl_tokens as f64;
worker_scores.insert(worker_id, score);
}
// Track max waiting requests
max_active = f64::max(max_active, ep.data.request_active_slots as f64);
} }
let kv_load_ratio = w.data.kv_active_blocks as f64 / w.data.kv_total_blocks as f64; if max_active == 0.0 {
let load_deviation = kv_load_ratio - workers.load_avg; return Err(KvSchedulerError::NoEndpoints);
}
let worker_id = w.worker_id(); // make immutable
let overlap_score = request.overlap.scores.get(&worker_id).map_or(0, |x| *x); let worker_scores = worker_scores;
let overlap_score = overlap_score as usize * kv_block_size; let max_active = max_active;
let new_tokens = request.isl_tokens.saturating_sub(overlap_score); // Calculate logits for each worker
let normalized_new_tokens = new_tokens as f64 / request.isl_tokens as f64; let mut best_logit = f64::NEG_INFINITY;
let mut best_workers = Vec::new();
let request_load_ratio = for (worker_id, ep) in workers.endpoints.iter() {
w.data.request_active_slots as f64 / w.data.request_total_slots as f64; let worker_id = *worker_id;
// cost = alpha * load_deviation + (1 - alpha)*normalized_new_tokens + gamma * request_load_ratio // Get score or default to 0.0
let cost = alpha * load_deviation let score = worker_scores.get(&worker_id).copied().unwrap_or(0.0);
+ (1.0 - alpha) * normalized_new_tokens
+ gamma * request_load_ratio;
tracing::debug!("worker: {}; load_deviation: {}; normalized new blocks: {}; request_load_ratio: {} cost: {}", // Calculate normalized metrics
assert!(ep.data.kv_total_blocks > 0);
let gpu_cache_usage = ep.data.kv_active_blocks as f64 / ep.data.kv_total_blocks as f64;
let normalized_active = if max_active > 0.0 {
ep.data.request_active_slots as f64 / max_active
} else {
0.0
};
// Calculate logit using same formula as Python
let logit = 2.0 * score - gpu_cache_usage - normalized_active;
tracing::info!(
"Formula for {}: {:.3} = 2.0 * {:.3} - {:.3} - {:.3}",
worker_id, worker_id,
load_deviation, logit,
normalized_new_tokens, score,
request_load_ratio, gpu_cache_usage,
cost normalized_active
); );
if cost < best_cost { // Track best workers
best_cost = cost; match logit.partial_cmp(&best_logit) {
best_index = Some(i); Some(std::cmp::Ordering::Greater) => {
best_logit = logit;
best_workers.clear();
best_workers.push(worker_id);
}
Some(std::cmp::Ordering::Equal) => {
best_workers.push(worker_id);
}
_ => {}
}
} }
}
if let Some(best_index) = best_index { // Return early if no valid workers found
let total_blocks = min(request.isl_tokens / kv_block_size, 1); if best_workers.is_empty() || best_logit == 0.0 {
return Err(KvSchedulerError::NoEndpoints);
workers.endpoints[best_index].data.request_active_slots += 1;
workers.endpoints[best_index].data.kv_active_blocks += total_blocks as u64;
// Optimization - pass this to a channel for emitting events, async task, etc. to avoid blocking the scheduler
let best_worker_id = workers.endpoints[best_index].worker_id();
let isl_blocks = request.isl_tokens / kv_block_size;
let overlap_blocks = request
.overlap
.scores
.get(&best_worker_id)
.copied()
.unwrap_or(0);
if let Err(e) = event_tx.send(KVHitRateEvent {
worker_id: best_worker_id,
isl_blocks,
overlap_blocks: overlap_blocks as usize,
}) {
tracing::warn!("Failed to send KV hit rate event: {:?}", e);
} }
}
match best_index { let worker_id = if best_workers.len() == 1 {
Some(i) => { best_workers[0]
tracing::info!( } else {
"selected worker: {}; cost: {}", // Randomly select from best workers
workers.endpoints[i].worker_id(), let mut rng = rand::rng();
best_cost best_workers[rng.random_range(0..best_workers.len())]
); };
Ok(workers.endpoints[i].worker_id())
} // Log selection metrics
None => { tracing::info!("Selected worker: {}, logit: {:.3}", worker_id, best_logit);
tracing::debug!("all workers busy");
Err(KvSchedulerError::AllWorkersBusy) let total_blocks = std::cmp::min(request.isl_tokens / block_size, 1) as u64;
} let overlap_blocks = request.overlap.scores.get(&worker_id).copied().unwrap_or(0) as usize;
Ok(WorkerSelectionResult {
worker_id,
required_blocks: total_blocks,
overlap_blocks,
})
} }
} }
...@@ -16,12 +16,13 @@ ...@@ -16,12 +16,13 @@
//! Scoring functions for the KV router. //! Scoring functions for the KV router.
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::kv_router::scheduler::Endpoint; use crate::kv_router::scheduler::Endpoint;
#[derive(Debug, Default, Serialize, Deserialize, Clone)] #[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct ProcessedEndpoints { pub struct ProcessedEndpoints {
pub endpoints: Vec<Endpoint>, pub endpoints: HashMap<i64, Endpoint>,
pub load_avg: f64, pub load_avg: f64,
pub load_std: f64, pub load_std: f64,
} }
...@@ -41,6 +42,8 @@ impl ProcessedEndpoints { ...@@ -41,6 +42,8 @@ impl ProcessedEndpoints {
/ load_values.len() as f64; / load_values.len() as f64;
let load_std = variance.sqrt(); let load_std = variance.sqrt();
let endpoints = endpoints.into_iter().map(|e| (e.worker_id(), e)).collect();
ProcessedEndpoints { ProcessedEndpoints {
endpoints, endpoints,
load_avg, load_avg,
......
...@@ -105,6 +105,13 @@ impl Tokens { ...@@ -105,6 +105,13 @@ impl Tokens {
pub fn into_sequence(self, block_size: usize) -> TokenSequence { pub fn into_sequence(self, block_size: usize) -> TokenSequence {
TokenSequence::new(self, block_size) TokenSequence::new(self, block_size)
} }
/// Computes one hash per full block of `block_size` tokens, in parallel.
///
/// Trailing tokens that do not fill a complete block are ignored
/// (`par_chunks_exact` drops the remainder).
///
/// NOTE(review): rayon's `par_chunks_exact` panics when `block_size` is 0 —
/// confirm callers validate the block size before reaching this.
pub fn compute_block_hash(tokens: &[Token], block_size: usize) -> Vec<BlockHash> {
    tokens
        .par_chunks_exact(block_size)
        .map(|chunk| compute_hash(cast_slice(chunk)))
        .collect()
}
} }
pub struct PartialTokenBlock { pub struct PartialTokenBlock {
......
...@@ -34,6 +34,7 @@ pub mod discovery; ...@@ -34,6 +34,7 @@ pub mod discovery;
pub mod engine; pub mod engine;
pub mod logging; pub mod logging;
pub mod pipeline; pub mod pipeline;
pub mod prelude;
pub mod protocols; pub mod protocols;
pub mod runnable; pub mod runnable;
pub mod runtime; pub mod runtime;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/// Re-export the crate's core traits so downstream code can pull in
/// everything with a single `prelude::*` import.
pub use crate::traits::*;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment