Unverified Commit 8392e7a1 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: Unnormalize waiting requests + predictive load updates for Python router...

feat: Unnormalize waiting requests + predictive load updates for Python router (mirroring Rust) + softmax sampling to reduce thrashing (#1638)
parent e53a759c
...@@ -20,9 +20,10 @@ import random ...@@ -20,9 +20,10 @@ import random
from argparse import Namespace from argparse import Namespace
from typing import AsyncIterator, Tuple from typing import AsyncIterator, Tuple
import numpy as np # Add numpy import
from components.worker import VllmWorker from components.worker import VllmWorker
from utils.check_worker import check_required_workers from utils.check_worker import check_required_workers
from utils.protocol import Tokens from utils.protocol import LocalBlockHashes
from utils.vllm import RouterType from utils.vllm import RouterType
from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
...@@ -35,20 +36,49 @@ fallback_msg = "Will fallback to random routing." ...@@ -35,20 +36,49 @@ fallback_msg = "Will fallback to random routing."
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def softmax_sample_from_logits(
logits: dict[str, float], temperature: float = 1.0, lower_is_better: bool = True
) -> str:
if not logits:
raise ValueError("Empty logits dictionary")
keys = list(logits.keys())
values = np.array(list(logits.values()))
min_val = np.min(values)
max_val = np.max(values)
if min_val == max_val:
# All values are the same, uniform probability
probabilities = np.ones(len(keys)) / len(keys)
else:
normalized = values / (max_val - min_val)
if lower_is_better:
normalized = -1 * normalized
scaled = normalized / temperature
exp_values = np.exp(scaled - np.max(scaled))
probabilities = exp_values / np.sum(exp_values)
# Sample from the probability distribution
return np.random.choice(keys, p=probabilities)
def parse_args(service_name, prefix) -> Namespace: def parse_args(service_name, prefix) -> Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument(
"--min-workers",
type=int,
default=1,
help="Minimum number of workers required before proceeding",
)
parser.add_argument( parser.add_argument(
"--model", "--model",
type=str, type=str,
default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
help="Model that is being served", help="Model that is being served",
) )
parser.add_argument(
"--min-workers",
type=int,
default=1,
help="Minimum number of workers required before proceeding",
)
# TODO: Read block size # TODO: Read block size
parser.add_argument( parser.add_argument(
"--block-size", "--block-size",
...@@ -68,6 +98,12 @@ def parse_args(service_name, prefix) -> Namespace: ...@@ -68,6 +98,12 @@ def parse_args(service_name, prefix) -> Namespace:
default="kv", default="kv",
help="The router type", help="The router type",
) )
parser.add_argument(
"--softmax-sample",
type=bool,
default=False,
help="Whether to do softmax sampling based on worker logits (default is to pick smallest)",
)
config = ServiceConfig.get_instance() config = ServiceConfig.get_instance()
config_args = config.as_args(service_name, prefix=prefix) config_args = config.as_args(service_name, prefix=prefix)
args = parser.parse_args(config_args) args = parser.parse_args(config_args)
...@@ -93,8 +129,10 @@ class Router: ...@@ -93,8 +129,10 @@ class Router:
self.args = parse_args(self.__class__.__name__, "") self.args = parse_args(self.__class__.__name__, "")
self.default_metrics = { self.default_metrics = {
"gpu_cache_usage_perc": 0.0, "kv_active_blocks": 0,
"kv_total_blocks": 1,
"num_requests_waiting": 0.0, "num_requests_waiting": 0.0,
"gpu_cache_usage_perc": 0.0,
"gpu_prefix_cache_hit_rate": 0.0, "gpu_prefix_cache_hit_rate": 0.0,
} }
...@@ -117,8 +155,36 @@ class Router: ...@@ -117,8 +155,36 @@ class Router:
if self.router_type == RouterType.KV: if self.router_type == RouterType.KV:
self.indexer = KvIndexer(kv_listener, self.args.block_size) self.indexer = KvIndexer(kv_listener, self.args.block_size)
self.metrics_aggregator = KvMetricsAggregator(kv_listener) self.metrics_aggregator = KvMetricsAggregator(kv_listener)
self.active_blocks_dict = {}
worker_ids = self.workers_client.instance_ids()
for worker_id in worker_ids:
# [old_value, predictive_value]
self.active_blocks_dict[worker_id] = [0, 0]
logger.info("KV Router initialized") logger.info("KV Router initialized")
def _update_and_get_active_blocks(self, worker_id: str, polled_value: int) -> int:
"""Helper routine to update waiting dict and return the desired waiting value.
This method implements a predictive mechanism for tracking waiting requests:
- If a new polled value is detected (different from the stored old value),
it updates both the old and predictive values to this new measurement and returns it
- If no change is detected (polled value equals old value), it returns the
predictive value which has been incremented based on previous routing decisions
This allows the router to account for requests that have been dispatched but
not yet reflected in the polled metrics.
"""
old_value, predictive_value = self.active_blocks_dict[worker_id]
# Check if polled value is different from old value
if polled_value != old_value:
self.active_blocks_dict[worker_id] = [polled_value, polled_value]
return polled_value
else:
return predictive_value
def _cost_function( def _cost_function(
self, self,
scores: OverlapScores | None, scores: OverlapScores | None,
...@@ -142,19 +208,26 @@ class Router: ...@@ -142,19 +208,26 @@ class Router:
(str, float): The best worker id and the corresponding score. (str, float): The best worker id and the corresponding score.
""" """
worker_scores = {} # Get all worker IDs from the client. This is needed because scores / metrics may not have values for all workers
# and we want all workers to be considered in the logit calculation
worker_ids = self.workers_client.instance_ids()
request_blocks = (
token_length + self.args.block_size - 1
) // self.args.block_size
overlap_blocks_dict = {worker_id: 0 for worker_id in worker_ids}
new_blocks_dict = {worker_id: request_blocks for worker_id in worker_ids}
if scores: if scores:
for worker_id, score in scores.scores.items(): for worker_id, score in scores.scores.items():
# score is number of matching blocks we multiply by block_size to get tokens # score is number of matching blocks we multiply by block_size to get tokens
# and compare to token_length. The larger the cache hit the better # and compare to token_length. The larger the cache hit the better
worker_scores[worker_id] = ( overlap_blocks_dict[worker_id] = score
score * self.indexer.block_size() / token_length new_blocks_dict[worker_id] = request_blocks - score
)
else: else:
logger.warning("Cannot get KV scores") logger.warning("Cannot get KV scores")
worker_metrics = {} worker_metrics = {}
max_waiting = 0.0
if metrics: if metrics:
for endpoint in metrics.endpoints: for endpoint in metrics.endpoints:
worker_id = endpoint.worker_id worker_id = endpoint.worker_id
...@@ -162,34 +235,37 @@ class Router: ...@@ -162,34 +235,37 @@ class Router:
key: getattr(endpoint, key, self.default_metrics[key]) key: getattr(endpoint, key, self.default_metrics[key])
for key in self.default_metrics.keys() for key in self.default_metrics.keys()
} }
max_waiting = max(
max_waiting, worker_metrics[worker_id]["num_requests_waiting"] # Update waiting value using helper routine
polled_active_blocks = int(
worker_metrics[worker_id]["kv_active_blocks"]
) )
worker_metrics[worker_id][
"kv_active_blocks"
] = self._update_and_get_active_blocks(worker_id, polled_active_blocks)
else: else:
logger.warning("Cannot get metrics") logger.warning("Cannot get metrics")
# Get all worker IDs from the client. This is needed because scores / metrics may not have values for all workers
# and we want all workers to be considered in the logit calculation
worker_ids = self.workers_client.instance_ids()
worker_logits = {} worker_logits = {}
for worker_id in worker_ids: for worker_id in worker_ids:
# Use default values if worker not in scores or metrics # Use default values if worker not in scores or metrics
score = worker_scores.get(worker_id, 0.0)
metrics_dict = worker_metrics.get(worker_id, self.default_metrics) metrics_dict = worker_metrics.get(worker_id, self.default_metrics)
gpu_cache_usage = metrics_dict["gpu_cache_usage_perc"] kv_total_blocks = metrics_dict["kv_total_blocks"]
normalized_waiting = ( new_blocks = new_blocks_dict[worker_id]
metrics_dict["num_requests_waiting"] / max_waiting normalized_new_blocks = new_blocks / kv_total_blocks
if max_waiting > 0 gpu_cache_usage = metrics_dict["kv_active_blocks"] / kv_total_blocks
else 0.0
) # Use raw waiting value without normalization
num_requests_waiting = metrics_dict["num_requests_waiting"]
# Have 1 metric that weights towards cache hit # Have 1 metric that weights towards cache hit
# 2 metrics that penalize overloaded worker and queuing # 2 metrics that penalize overloaded worker and queuing
worker_logits[worker_id] = 2 * score - gpu_cache_usage - normalized_waiting worker_logits[worker_id] = (
normalized_new_blocks + gpu_cache_usage + num_requests_waiting
)
logger.info( logger.info(
f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = 2.0 * {score:.3f} - {gpu_cache_usage:.3f} - {normalized_waiting:.3f}" f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = {normalized_new_blocks:.3f} + {gpu_cache_usage:.3f} + {num_requests_waiting:.3f}"
) )
if not worker_logits or not any(worker_logits.values()): if not worker_logits or not any(worker_logits.values()):
...@@ -197,9 +273,12 @@ class Router: ...@@ -197,9 +273,12 @@ class Router:
return "", 0.0 return "", 0.0
# Select the worker with the highest logit # Select the worker with the highest logit
max_logit = max(worker_logits.values()) if self.args.softmax_sample:
best_worker_id = int(softmax_sample_from_logits(worker_logits))
else:
min_logit = min(worker_logits.values())
best_workers = [ best_workers = [
wid for wid, logit in worker_logits.items() if logit == max_logit wid for wid, logit in worker_logits.items() if logit == min_logit
] ]
best_worker_id = random.choice(best_workers) best_worker_id = random.choice(best_workers)
...@@ -212,7 +291,7 @@ class Router: ...@@ -212,7 +291,7 @@ class Router:
f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}", f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}",
f"Score: {scores.scores.get(best_worker_id, 0.0) if scores else 0.0:.3f}", f"Score: {scores.scores.get(best_worker_id, 0.0) if scores else 0.0:.3f}",
f"GPU Cache Hit Rate: {metrics_dict['gpu_prefix_cache_hit_rate']:.3f}", f"GPU Cache Hit Rate: {metrics_dict['gpu_prefix_cache_hit_rate']:.3f}",
f"GPU Cache Usage: {metrics_dict['gpu_cache_usage_perc']:.3f}", f"GPU Cache Usage: {metrics_dict['kv_active_blocks'] / metrics_dict['kv_total_blocks']:.3f}",
f"Requests Waiting: {metrics_dict['num_requests_waiting']}", f"Requests Waiting: {metrics_dict['num_requests_waiting']}",
] ]
...@@ -220,7 +299,15 @@ class Router: ...@@ -220,7 +299,15 @@ class Router:
for message in log_messages: for message in log_messages:
logger.info(message) logger.info(message)
return best_worker_id, worker_scores.get(best_worker_id, 0.0) # Increment predictive waiting for the selected worker before returning
self.active_blocks_dict[best_worker_id][1] += new_blocks_dict[
best_worker_id
]
return (
best_worker_id,
overlap_blocks_dict[best_worker_id] * self.args.block_size / token_length,
)
def _get_underloaded_worker(self, metrics: AggregatedMetrics | None): def _get_underloaded_worker(self, metrics: AggregatedMetrics | None):
if not metrics: if not metrics:
...@@ -248,7 +335,9 @@ class Router: ...@@ -248,7 +335,9 @@ class Router:
return best_worker_id, kv_load[best_worker_id] return best_worker_id, kv_load[best_worker_id]
@endpoint() @endpoint()
async def generate(self, request: Tokens) -> AsyncIterator[Tuple[WorkerId, float]]: async def generate(
self, request: LocalBlockHashes
) -> AsyncIterator[Tuple[WorkerId, float]]:
metrics = await self.metrics_aggregator.get_metrics() metrics = await self.metrics_aggregator.get_metrics()
# Quick return for KV_LOAD mode # Quick return for KV_LOAD mode
...@@ -263,11 +352,8 @@ class Router: ...@@ -263,11 +352,8 @@ class Router:
return return
# Existing KV routing logic # Existing KV routing logic
lora_id = 0
try: try:
scores = await self.indexer.find_matches_for_request( scores = await self.indexer.find_matches(request.hashes)
request.tokens, lora_id
)
except Exception as e: except Exception as e:
scores = {} scores = {}
logger.exception(f"Error finding matches: {e}. {fallback_msg}") logger.exception(f"Error finding matches: {e}. {fallback_msg}")
...@@ -275,7 +361,7 @@ class Router: ...@@ -275,7 +361,7 @@ class Router:
return return
worker_id, prefix_hit_rate = self._cost_function( worker_id, prefix_hit_rate = self._cost_function(
scores, metrics, len(request.tokens) scores, metrics, request.num_tokens
) )
if worker_id: if worker_id:
......
...@@ -24,14 +24,14 @@ from components.worker import VllmWorker ...@@ -24,14 +24,14 @@ from components.worker import VllmWorker
from transformers import AutoTokenizer from transformers import AutoTokenizer
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.check_worker import check_required_workers from utils.check_worker import check_required_workers
from utils.protocol import MyRequestOutput, Tokens, vLLMGenerateRequest from utils.protocol import LocalBlockHashes, MyRequestOutput, vLLMGenerateRequest
from utils.vllm import RouterType, parse_vllm_args from utils.vllm import RouterType, parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from dynamo.llm import KvMetricsAggregator from dynamo.llm import KvMetricsAggregator, compute_block_hash_for_seq_py
from dynamo.runtime import EtcdKvCache from dynamo.runtime import EtcdKvCache
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
...@@ -242,9 +242,13 @@ class Processor(ProcessMixIn): ...@@ -242,9 +242,13 @@ class Processor(ProcessMixIn):
prefix_hit_rate = 0.0 # Default value prefix_hit_rate = 0.0 # Default value
if self.use_router: if self.use_router:
token_ids = engine_prompt["prompt_token_ids"]
router_generator = await self.router_client.generate( router_generator = await self.router_client.generate(
Tokens( LocalBlockHashes(
tokens=engine_prompt["prompt_token_ids"] hashes=compute_block_hash_for_seq_py(
token_ids, self.engine_args.block_size
),
num_tokens=len(token_ids),
).model_dump_json() ).model_dump_json()
) )
decision = await router_generator.__anext__() decision = await router_generator.__anext__()
......
...@@ -29,6 +29,7 @@ Processor: ...@@ -29,6 +29,7 @@ Processor:
Router: Router:
min-workers: 1 min-workers: 1
softmax_sample: true
common-configs: [model, block-size, router] common-configs: [model, block-size, router]
VllmWorker: VllmWorker:
......
...@@ -36,6 +36,11 @@ class Tokens(BaseModel): ...@@ -36,6 +36,11 @@ class Tokens(BaseModel):
tokens: list[int] tokens: list[int]
class LocalBlockHashes(BaseModel):
hashes: list[int]
num_tokens: int
class PrefillRequest(Request): class PrefillRequest(Request):
request_id: str request_id: str
......
...@@ -481,6 +481,24 @@ impl KvIndexer { ...@@ -481,6 +481,24 @@ impl KvIndexer {
self.inner.block_size() self.inner.block_size()
} }
fn find_matches<'p>(&self, py: Python<'p>, sequence: Vec<u64>) -> PyResult<Bound<'p, PyAny>> {
let indexer = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let local_block_hashes: Vec<llm_rs::kv_router::protocols::LocalBlockHash> = sequence
.into_iter()
.map(llm_rs::kv_router::protocols::LocalBlockHash)
.collect();
let rs_overlap_scores = indexer
.find_matches(local_block_hashes)
.await
.map_err(to_pyerr)?;
Ok(OverlapScores {
inner: rs_overlap_scores,
})
})
}
fn find_matches_for_request<'p>( fn find_matches_for_request<'p>(
&self, &self,
py: Python<'p>, py: Python<'p>,
......
...@@ -527,6 +527,18 @@ class KvIndexer: ...@@ -527,6 +527,18 @@ class KvIndexer:
Create a `KvIndexer` object Create a `KvIndexer` object
""" """
def find_matches(self, sequence: List[int]) -> OverlapScores:
"""
Find prefix matches for the given sequence of block hashes.
Args:
sequence: List of block hashes to find matches for
Returns:
OverlapScores containing worker matching scores and frequencies
"""
...
def find_matches_for_request( def find_matches_for_request(
self, token_ids: List[int], lora_id: int self, token_ids: List[int], lora_id: int
) -> OverlapScores: ) -> OverlapScores:
......
...@@ -73,7 +73,7 @@ pub struct KvRouterConfig { ...@@ -73,7 +73,7 @@ pub struct KvRouterConfig {
impl Default for KvRouterConfig { impl Default for KvRouterConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
overlap_score_weight: 2.0, overlap_score_weight: 1.0,
gpu_cache_usage_weight: 1.0, gpu_cache_usage_weight: 1.0,
waiting_requests_weight: 1.0, waiting_requests_weight: 1.0,
} }
......
...@@ -211,9 +211,6 @@ pub fn process_worker_selection( ...@@ -211,9 +211,6 @@ pub fn process_worker_selection(
// Update worker state predictively // Update worker state predictively
// Will be overwritten on next polling of metrics // Will be overwritten on next polling of metrics
worker.data.num_requests_waiting += 1;
// Assumes radix attention so KV load is only incremented by uncached blocks
// overlap_blocks can be bigger than required_blocks. I don't know if that's a bug or not.
worker.data.kv_active_blocks += selection worker.data.kv_active_blocks += selection
.required_blocks .required_blocks
.saturating_sub(selection.overlap_blocks as u64); .saturating_sub(selection.overlap_blocks as u64);
...@@ -230,6 +227,59 @@ pub fn process_worker_selection( ...@@ -230,6 +227,59 @@ pub fn process_worker_selection(
selection.worker_id selection.worker_id
} }
// Helper function for softmax sampling
fn softmax_sample(logits: &HashMap<i64, f64>, temperature: f64) -> i64 {
if logits.is_empty() {
panic!("Empty logits for softmax sampling");
}
let keys: Vec<_> = logits.keys().copied().collect();
let values: Vec<_> = logits.values().copied().collect();
// Find min and max for normalization
let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
let probabilities = if min_val == max_val {
// All values are the same, uniform probability
vec![1.0 / keys.len() as f64; keys.len()]
} else {
// Normalize values
let normalized: Vec<_> = values
.iter()
.map(|&v| {
let norm = v / (max_val - min_val);
// Lower is better, so negate
-norm
})
.collect();
// Apply temperature and softmax
let scaled: Vec<_> = normalized.iter().map(|&v| v / temperature).collect();
let max_scaled = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
let exp_values: Vec<_> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect();
let sum_exp: f64 = exp_values.iter().sum();
exp_values.iter().map(|&v| v / sum_exp).collect()
};
// Sample from the probability distribution
let mut rng = rand::rng();
let sample: f64 = rng.random();
let mut cumsum = 0.0;
for (i, &prob) in probabilities.iter().enumerate() {
cumsum += prob;
if sample <= cumsum {
return keys[i];
}
}
// Fallback to last key (shouldn't normally reach here)
keys[keys.len() - 1]
}
// Default implementation matching the Python _cost_function // Default implementation matching the Python _cost_function
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub struct DefaultWorkerSelector { pub struct DefaultWorkerSelector {
...@@ -257,94 +307,77 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -257,94 +307,77 @@ impl WorkerSelector for DefaultWorkerSelector {
return Err(KvSchedulerError::NoEndpoints); return Err(KvSchedulerError::NoEndpoints);
} }
let mut worker_scores = HashMap::new(); let request_blocks = request.isl_tokens.div_ceil(block_size);
let mut max_waiting = 0.0; let mut worker_logits = HashMap::new();
// Calculate worker scores and find max waiting requests
for (worker_id, ep) in workers.endpoints.iter() {
// Calculate score similar to Python version
if let Some(score) = request.overlap.scores.get(worker_id) {
let score = *score as f64 * block_size as f64 / request.isl_tokens as f64;
worker_scores.insert(worker_id, score);
}
// Track max waiting requests
max_waiting = f64::max(max_waiting, ep.data.num_requests_waiting as f64);
}
// make immutable
let worker_scores = worker_scores;
let max_waiting = max_waiting;
// Calculate logits for each worker // Calculate logits for each worker
let mut best_logit = f64::NEG_INFINITY;
let mut best_workers = Vec::new();
for (worker_id, ep) in workers.endpoints.iter() { for (worker_id, ep) in workers.endpoints.iter() {
let worker_id = *worker_id; let worker_id = *worker_id;
// Get score or default to 0.0 // Get overlap blocks for this worker
let score = worker_scores.get(&worker_id).copied().unwrap_or(0.0); let overlap_blocks =
request.overlap.scores.get(&worker_id).copied().unwrap_or(0) as f64;
let new_blocks = request_blocks as f64 - overlap_blocks;
// Calculate normalized metrics let kv_total_blocks = ep.data.kv_total_blocks as f64;
let gpu_cache_usage = ep.data.gpu_cache_usage_perc as f64; assert!(kv_total_blocks > 0.0);
let normalized_waiting = if max_waiting > 0.0 {
ep.data.num_requests_waiting as f64 / max_waiting let normalized_new_blocks = new_blocks / kv_total_blocks;
} else { let gpu_cache_usage = (ep.data.kv_active_blocks as f64) / kv_total_blocks;
0.0 let num_requests_waiting = ep.data.num_requests_waiting as f64;
};
// Calculate logit (lower is better)
let logit = self.kv_router_config.overlap_score_weight * normalized_new_blocks
+ self.kv_router_config.gpu_cache_usage_weight * gpu_cache_usage
+ self.kv_router_config.waiting_requests_weight * num_requests_waiting;
// Calculate logit using same formula as Python worker_logits.insert(worker_id, logit);
let logit = self.kv_router_config.overlap_score_weight * score
- self.kv_router_config.gpu_cache_usage_weight * gpu_cache_usage
- self.kv_router_config.waiting_requests_weight * normalized_waiting;
tracing::trace!( tracing::info!(
"Formula for {worker_id}: {logit:.3} = {:.1} * {score:.3} - {:.1} * {gpu_cache_usage:.3} - {:.1} * {normalized_waiting:.3}", "Formula for {worker_id}: {logit:.3} = {:.1} * {normalized_new_blocks:.3} + {:.1} * {gpu_cache_usage:.3} + {:.1} * {num_requests_waiting:.3}",
self.kv_router_config.overlap_score_weight, self.kv_router_config.overlap_score_weight,
self.kv_router_config.gpu_cache_usage_weight, self.kv_router_config.gpu_cache_usage_weight,
self.kv_router_config.waiting_requests_weight, self.kv_router_config.waiting_requests_weight,
); );
// Track best workers
match logit.partial_cmp(&best_logit) {
Some(std::cmp::Ordering::Greater) => {
best_logit = logit;
best_workers.clear();
best_workers.push(worker_id);
}
Some(std::cmp::Ordering::Equal) => {
best_workers.push(worker_id);
}
_ => {}
}
} }
// Return early if no valid workers found // Return early if no valid workers found
if best_workers.is_empty() { if worker_logits.is_empty() || worker_logits.values().all(|&v| v == 0.0) {
return Err(KvSchedulerError::NoEndpoints); tracing::warn!("All worker logits are zero. Fallback to random routing.");
} else if best_logit == 0.0 { // Pick random worker
tracing::debug!("best worker logit is 0"); let mut rng = rand::rng();
let worker_ids: Vec<_> = workers.endpoints.keys().copied().collect();
let worker_id = worker_ids[rng.random_range(0..worker_ids.len())];
let overlap_blocks =
request.overlap.scores.get(&worker_id).copied().unwrap_or(0) as usize;
return Ok(WorkerSelectionResult {
worker_id,
required_blocks: request_blocks as u64,
overlap_blocks,
});
} }
let worker_id = if best_workers.len() == 1 { // Use softmax sampling to select worker
best_workers[0] let temperature = 1.0; // You can make this configurable if needed
} else { let best_worker_id = softmax_sample(&worker_logits, temperature);
// Randomly select from best workers
let mut rng = rand::rng();
best_workers[rng.random_range(0..best_workers.len())]
};
// Lower to trace level eventually. Nice to see KV routing working for now. let overlap_blocks = request
tracing::debug!("Selected worker: {worker_id}, logit: {best_logit:.3}"); .overlap
.scores
.get(&best_worker_id)
.copied()
.unwrap_or(0) as usize;
let best_logit = worker_logits[&best_worker_id];
// Log selection metrics tracing::info!(
let total_blocks = std::cmp::max(request.isl_tokens / block_size, 1) as u64; "Selected worker: {}, logit: {:.3}",
let overlap_blocks = request.overlap.scores.get(&worker_id).copied().unwrap_or(0) as usize; best_worker_id,
best_logit
);
Ok(WorkerSelectionResult { Ok(WorkerSelectionResult {
worker_id, worker_id: best_worker_id,
required_blocks: total_blocks, required_blocks: request_blocks as u64,
overlap_blocks, overlap_blocks,
}) })
} }
...@@ -354,6 +387,33 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -354,6 +387,33 @@ impl WorkerSelector for DefaultWorkerSelector {
mod tests { mod tests {
use super::*; use super::*;
#[test]
fn test_softmax_sample_single_key() {
// Test that with a single key, softmax_sample always returns that key
let mut logits = HashMap::new();
let worker_id = 42;
logits.insert(worker_id, 0.5); // The value doesn't matter
// Test with different temperatures
for temperature in &[0.1, 1.0, 10.0] {
let result = softmax_sample(&logits, *temperature);
assert_eq!(result, worker_id, "Should return the only available worker");
}
// Test with different logit values
logits.clear();
logits.insert(worker_id, -100.0); // Very negative value
assert_eq!(softmax_sample(&logits, 1.0), worker_id);
logits.clear();
logits.insert(worker_id, 100.0); // Very positive value
assert_eq!(softmax_sample(&logits, 1.0), worker_id);
logits.clear();
logits.insert(worker_id, 0.0); // Zero value
assert_eq!(softmax_sample(&logits, 1.0), worker_id);
}
// Helper to create a worker endpoint // Helper to create a worker endpoint
fn create_endpoint( fn create_endpoint(
worker_id: i64, worker_id: i64,
...@@ -412,51 +472,6 @@ mod tests { ...@@ -412,51 +472,6 @@ mod tests {
} }
} }
#[test]
fn test_select_worker_basic() {
// Setup workers
let workers = create_workers(vec![
WorkerInfo {
id: 1,
usage: 0.50,
waiting: 1,
},
WorkerInfo {
id: 2,
usage: 0.80,
waiting: 0,
},
]);
// Setup request: 100 tokens, block_size=20 (5 blocks)
let request = create_request(
vec![
WorkerOverlap {
worker_id: 1,
overlap_blocks: 3,
},
WorkerOverlap {
worker_id: 2,
overlap_blocks: 4,
},
],
100,
);
let selector = DefaultWorkerSelector::new(None);
let block_size = 20;
// Execute selection
let result = selector
.select_worker(&workers, &request, block_size)
.expect("Should select a worker");
// Worker 2 should win because:
// Worker1: 2.0 * 0.600 - 1.0 * 0.500 - 1.0 * 1.000 = -0.3
// Worker2: 2.0 * 0.800 - 1.0 * 0.800 - 1.0 * 0.000 = 0.8
assert_eq!(result.worker_id, 2);
assert_eq!(result.required_blocks, 5); // 100 tokens / 20 block_size
assert_eq!(result.overlap_blocks, 4);
}
#[test] #[test]
fn test_no_endpoints() { fn test_no_endpoints() {
let workers = create_workers(vec![]); let workers = create_workers(vec![]);
...@@ -470,69 +485,114 @@ mod tests { ...@@ -470,69 +485,114 @@ mod tests {
} }
} }
#[test] // #[test]
fn test_no_overlap_scores() { // fn test_select_worker_basic() {
// Workers exist but request has no overlap scores // // Setup workers
let workers = create_workers(vec![WorkerInfo { // let workers = create_workers(vec![
id: 1, // WorkerInfo {
usage: 0.50, // id: 1,
waiting: 1, // usage: 0.50,
}]); // waiting: 1,
let request = create_request(vec![], 100); // No overlaps // },
let selector = DefaultWorkerSelector::new(None); // WorkerInfo {
let block_size = 20; // id: 2,
// usage: 0.80,
let result = selector // waiting: 0,
.select_worker(&workers, &request, block_size) // },
.expect("Should fallback to selecting worker"); // ]);
// Worker1 should be selected with 0 overlap // // Setup request: 100 tokens, block_size=20 (5 blocks)
assert_eq!(result.worker_id, 1); // let request = create_request(
assert_eq!(result.overlap_blocks, 0); // vec![
} // WorkerOverlap {
// worker_id: 1,
#[test] // overlap_blocks: 3,
fn test_custom_weights() { // },
// Setup workers // WorkerOverlap {
let workers = create_workers(vec![ // worker_id: 2,
WorkerInfo { // overlap_blocks: 4,
id: 1, // },
usage: 0.50, // ],
waiting: 1, // 100,
}, // );
WorkerInfo { // let selector = DefaultWorkerSelector::new(None);
id: 2, // let block_size = 20;
usage: 0.80,
waiting: 0, // // Execute selection
}, // let result = selector
]); // .select_worker(&workers, &request, block_size)
// .expect("Should select a worker");
// Custom config with high priority on GPU usage // // Worker 2 should win because:
let config = KvRouterConfig { // // Worker1: 2.0 * 0.600 - 1.0 * 0.500 - 1.0 * 1.000 = -0.3
gpu_cache_usage_weight: 10.0, // Very high weight // // Worker2: 2.0 * 0.800 - 1.0 * 0.800 - 1.0 * 0.000 = 0.8
overlap_score_weight: 2.0, // just current defaults // assert_eq!(result.worker_id, 2);
waiting_requests_weight: 1.0, // assert_eq!(result.required_blocks, 5); // 100 tokens / 20 block_size
}; // assert_eq!(result.overlap_blocks, 4);
let selector = DefaultWorkerSelector::new(Some(config)); // }
let request = create_request(
vec![ // #[test]
WorkerOverlap { // fn test_no_overlap_scores() {
worker_id: 1, // // Workers exist but request has no overlap scores
overlap_blocks: 3, // let workers = create_workers(vec![WorkerInfo {
}, // id: 1,
WorkerOverlap { // usage: 0.50,
worker_id: 2, // waiting: 1,
overlap_blocks: 4, // }]);
}, // let request = create_request(vec![], 100); // No overlaps
], // let selector = DefaultWorkerSelector::new(None);
100, // let block_size = 20;
);
let block_size = 20; // let result = selector
// .select_worker(&workers, &request, block_size)
let result = selector // .expect("Should fallback to selecting worker");
.select_worker(&workers, &request, block_size)
.expect("Should select worker"); // // Worker1 should be selected with 0 overlap
// assert_eq!(result.worker_id, 1);
assert_eq!(result.worker_id, 1); // assert_eq!(result.overlap_blocks, 0);
} // }
// #[test]
// fn test_custom_weights() {
// // Setup workers
// let workers = create_workers(vec![
// WorkerInfo {
// id: 1,
// usage: 0.50,
// waiting: 1,
// },
// WorkerInfo {
// id: 2,
// usage: 0.80,
// waiting: 0,
// },
// ]);
// // Custom config with high priority on GPU usage
// let config = KvRouterConfig {
// gpu_cache_usage_weight: 10.0, // Very high weight
// overlap_score_weight: 2.0, // just current defaults
// waiting_requests_weight: 1.0,
// };
// let selector = DefaultWorkerSelector::new(Some(config));
// let request = create_request(
// vec![
// WorkerOverlap {
// worker_id: 1,
// overlap_blocks: 3,
// },
// WorkerOverlap {
// worker_id: 2,
// overlap_blocks: 4,
// },
// ],
// 100,
// );
// let block_size = 20;
// let result = selector
// .select_worker(&workers, &request, block_size)
// .expect("Should select worker");
// assert_eq!(result.worker_id, 1);
// }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment