Unverified Commit 8392e7a1 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: Unnormalize waiting requests + predictive load updates for Python router...

feat: Unnormalize waiting requests + predictive load updates for Python router (mirroring Rust) + softmax sampling to reduce thrashing (#1638)
parent e53a759c
......@@ -20,9 +20,10 @@ import random
from argparse import Namespace
from typing import AsyncIterator, Tuple
import numpy as np # Add numpy import
from components.worker import VllmWorker
from utils.check_worker import check_required_workers
from utils.protocol import Tokens
from utils.protocol import LocalBlockHashes
from utils.vllm import RouterType
from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
......@@ -35,20 +36,49 @@ fallback_msg = "Will fallback to random routing."
logger = logging.getLogger(__name__)
def softmax_sample_from_logits(
    logits: dict[str, float], temperature: float = 1.0, lower_is_better: bool = True
) -> str:
    """Sample one key from ``logits`` using a softmax over range-normalized values.

    The values are divided by their spread (``max - min``), negated when
    ``lower_is_better`` is set, divided by ``temperature`` and passed through a
    numerically stable softmax. One key is then drawn at random from the
    resulting distribution. When every value is identical the draw is uniform.

    Args:
        logits: Mapping from candidate key to its logit (cost or score).
        temperature: Softmax temperature; must be strictly positive. Smaller
            values concentrate the probability mass on the best candidate.
        lower_is_better: If True, smaller logits receive higher probability.

    Returns:
        The sampled key, returned as the very object stored in ``logits``
        (not a NumPy scalar), so its type is preserved.

    Raises:
        ValueError: If ``logits`` is empty or ``temperature`` is not positive.
    """
    if not logits:
        raise ValueError("Empty logits dictionary")
    if temperature <= 0:
        # Zero would divide by zero and a negative value would silently invert
        # the preference order, so reject both explicitly.
        raise ValueError("Temperature must be positive")

    keys = list(logits.keys())
    values = np.array(list(logits.values()), dtype=float)

    min_val = np.min(values)
    max_val = np.max(values)

    if min_val == max_val:
        # All values are the same, uniform probability
        probabilities = np.ones(len(keys)) / len(keys)
    else:
        normalized = values / (max_val - min_val)
        if lower_is_better:
            normalized = -normalized

        scaled = normalized / temperature

        # Subtract the maximum before exponentiating to avoid overflow.
        exp_values = np.exp(scaled - np.max(scaled))
        probabilities = exp_values / np.sum(exp_values)

    # Sample an index rather than the key itself: np.random.choice would coerce
    # the keys into a NumPy array and hand back a NumPy scalar instead of the
    # original key object.
    index = np.random.choice(len(keys), p=probabilities)
    return keys[index]
def parse_args(service_name, prefix) -> Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"--min-workers",
type=int,
default=1,
help="Minimum number of workers required before proceeding",
)
parser.add_argument(
"--model",
type=str,
default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
help="Model that is being served",
)
parser.add_argument(
"--min-workers",
type=int,
default=1,
help="Minimum number of workers required before proceeding",
)
# TODO: Read block size
parser.add_argument(
"--block-size",
......@@ -68,6 +98,12 @@ def parse_args(service_name, prefix) -> Namespace:
default="kv",
help="The router type",
)
parser.add_argument(
"--softmax-sample",
type=bool,
default=False,
help="Whether to do softmax sampling based on worker logits (default is to pick smallest)",
)
config = ServiceConfig.get_instance()
config_args = config.as_args(service_name, prefix=prefix)
args = parser.parse_args(config_args)
......@@ -93,8 +129,10 @@ class Router:
self.args = parse_args(self.__class__.__name__, "")
self.default_metrics = {
"gpu_cache_usage_perc": 0.0,
"kv_active_blocks": 0,
"kv_total_blocks": 1,
"num_requests_waiting": 0.0,
"gpu_cache_usage_perc": 0.0,
"gpu_prefix_cache_hit_rate": 0.0,
}
......@@ -117,8 +155,36 @@ class Router:
if self.router_type == RouterType.KV:
self.indexer = KvIndexer(kv_listener, self.args.block_size)
self.metrics_aggregator = KvMetricsAggregator(kv_listener)
self.active_blocks_dict = {}
worker_ids = self.workers_client.instance_ids()
for worker_id in worker_ids:
# [old_value, predictive_value]
self.active_blocks_dict[worker_id] = [0, 0]
logger.info("KV Router initialized")
def _update_and_get_active_blocks(self, worker_id: str, polled_value: int) -> int:
    """Return the active-block count to use for ``worker_id`` when routing.

    ``self.active_blocks_dict`` stores ``[old_value, predictive_value]`` per
    worker, where ``old_value`` is the last polled measurement and
    ``predictive_value`` is that measurement plus the blocks the router has
    added for requests it dispatched since then.

    - If the polled value differs from the stored old value, a fresh
      measurement has arrived: both entries are reset to it and it is returned.
    - If the polled value equals the stored old value, the metrics have not
      refreshed yet, so the predictive value is returned instead.

    This lets the router account for requests that have been dispatched but
    are not yet reflected in the polled metrics.

    Args:
        worker_id: Identifier of the worker whose metric was polled.
        polled_value: Active-block count most recently polled for the worker.

    Returns:
        The active-block count the cost function should use for the worker.
    """
    # A worker that was not present when the dict was built has no entry yet;
    # start it at [0, 0] instead of raising KeyError.
    old_value, predictive_value = self.active_blocks_dict.setdefault(
        worker_id, [0, 0]
    )

    if polled_value != old_value:
        self.active_blocks_dict[worker_id] = [polled_value, polled_value]
        return polled_value
    return predictive_value
def _cost_function(
self,
scores: OverlapScores | None,
......@@ -142,19 +208,26 @@ class Router:
(str, float): The best worker id and the corresponding score.
"""
worker_scores = {}
# Get all worker IDs from the client. This is needed because scores / metrics may not have values for all workers
# and we want all workers to be considered in the logit calculation
worker_ids = self.workers_client.instance_ids()
request_blocks = (
token_length + self.args.block_size - 1
) // self.args.block_size
overlap_blocks_dict = {worker_id: 0 for worker_id in worker_ids}
new_blocks_dict = {worker_id: request_blocks for worker_id in worker_ids}
if scores:
for worker_id, score in scores.scores.items():
# score is number of matching blocks we multiply by block_size to get tokens
# and compare to token_length. The larger the cache hit the better
worker_scores[worker_id] = (
score * self.indexer.block_size() / token_length
)
overlap_blocks_dict[worker_id] = score
new_blocks_dict[worker_id] = request_blocks - score
else:
logger.warning("Cannot get KV scores")
worker_metrics = {}
max_waiting = 0.0
if metrics:
for endpoint in metrics.endpoints:
worker_id = endpoint.worker_id
......@@ -162,34 +235,37 @@ class Router:
key: getattr(endpoint, key, self.default_metrics[key])
for key in self.default_metrics.keys()
}
max_waiting = max(
max_waiting, worker_metrics[worker_id]["num_requests_waiting"]
# Update the active-blocks value using the helper routine
polled_active_blocks = int(
worker_metrics[worker_id]["kv_active_blocks"]
)
worker_metrics[worker_id][
"kv_active_blocks"
] = self._update_and_get_active_blocks(worker_id, polled_active_blocks)
else:
logger.warning("Cannot get metrics")
# Get all worker IDs from the client. This is needed because scores / metrics may not have values for all workers
# and we want all workers to be considered in the logit calculation
worker_ids = self.workers_client.instance_ids()
worker_logits = {}
for worker_id in worker_ids:
# Use default values if worker not in scores or metrics
score = worker_scores.get(worker_id, 0.0)
metrics_dict = worker_metrics.get(worker_id, self.default_metrics)
gpu_cache_usage = metrics_dict["gpu_cache_usage_perc"]
kv_total_blocks = metrics_dict["kv_total_blocks"]
normalized_waiting = (
metrics_dict["num_requests_waiting"] / max_waiting
if max_waiting > 0
else 0.0
)
new_blocks = new_blocks_dict[worker_id]
normalized_new_blocks = new_blocks / kv_total_blocks
gpu_cache_usage = metrics_dict["kv_active_blocks"] / kv_total_blocks
# Use raw waiting value without normalization
num_requests_waiting = metrics_dict["num_requests_waiting"]
# Have 1 metric that weights towards cache hit
# 2 metrics that penalize overloaded worker and queuing
worker_logits[worker_id] = 2 * score - gpu_cache_usage - normalized_waiting
worker_logits[worker_id] = (
normalized_new_blocks + gpu_cache_usage + num_requests_waiting
)
logger.info(
f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = 2.0 * {score:.3f} - {gpu_cache_usage:.3f} - {normalized_waiting:.3f}"
f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = {normalized_new_blocks:.3f} + {gpu_cache_usage:.3f} + {num_requests_waiting:.3f}"
)
if not worker_logits or not any(worker_logits.values()):
......@@ -197,11 +273,14 @@ class Router:
return "", 0.0
# Select the worker with the lowest logit (the logit is a cost, so lower is better)
max_logit = max(worker_logits.values())
best_workers = [
wid for wid, logit in worker_logits.items() if logit == max_logit
]
best_worker_id = random.choice(best_workers)
if self.args.softmax_sample:
best_worker_id = int(softmax_sample_from_logits(worker_logits))
else:
min_logit = min(worker_logits.values())
best_workers = [
wid for wid, logit in worker_logits.items() if logit == min_logit
]
best_worker_id = random.choice(best_workers)
# Log the metrics for the selected worker
if best_worker_id:
......@@ -212,7 +291,7 @@ class Router:
f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}",
f"Score: {scores.scores.get(best_worker_id, 0.0) if scores else 0.0:.3f}",
f"GPU Cache Hit Rate: {metrics_dict['gpu_prefix_cache_hit_rate']:.3f}",
f"GPU Cache Usage: {metrics_dict['gpu_cache_usage_perc']:.3f}",
f"GPU Cache Usage: {metrics_dict['kv_active_blocks'] / metrics_dict['kv_total_blocks']:.3f}",
f"Requests Waiting: {metrics_dict['num_requests_waiting']}",
]
......@@ -220,7 +299,15 @@ class Router:
for message in log_messages:
logger.info(message)
return best_worker_id, worker_scores.get(best_worker_id, 0.0)
# Increment the predictive active-block count for the selected worker before returning
self.active_blocks_dict[best_worker_id][1] += new_blocks_dict[
best_worker_id
]
return (
best_worker_id,
overlap_blocks_dict[best_worker_id] * self.args.block_size / token_length,
)
def _get_underloaded_worker(self, metrics: AggregatedMetrics | None):
if not metrics:
......@@ -248,7 +335,9 @@ class Router:
return best_worker_id, kv_load[best_worker_id]
@endpoint()
async def generate(self, request: Tokens) -> AsyncIterator[Tuple[WorkerId, float]]:
async def generate(
self, request: LocalBlockHashes
) -> AsyncIterator[Tuple[WorkerId, float]]:
metrics = await self.metrics_aggregator.get_metrics()
# Quick return for KV_LOAD mode
......@@ -263,11 +352,8 @@ class Router:
return
# Existing KV routing logic
lora_id = 0
try:
scores = await self.indexer.find_matches_for_request(
request.tokens, lora_id
)
scores = await self.indexer.find_matches(request.hashes)
except Exception as e:
scores = {}
logger.exception(f"Error finding matches: {e}. {fallback_msg}")
......@@ -275,7 +361,7 @@ class Router:
return
worker_id, prefix_hit_rate = self._cost_function(
scores, metrics, len(request.tokens)
scores, metrics, request.num_tokens
)
if worker_id:
......
......@@ -24,14 +24,14 @@ from components.worker import VllmWorker
from transformers import AutoTokenizer
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.check_worker import check_required_workers
from utils.protocol import MyRequestOutput, Tokens, vLLMGenerateRequest
from utils.protocol import LocalBlockHashes, MyRequestOutput, vLLMGenerateRequest
from utils.vllm import RouterType, parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from dynamo.llm import KvMetricsAggregator
from dynamo.llm import KvMetricsAggregator, compute_block_hash_for_seq_py
from dynamo.runtime import EtcdKvCache
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
......@@ -242,9 +242,13 @@ class Processor(ProcessMixIn):
prefix_hit_rate = 0.0 # Default value
if self.use_router:
token_ids = engine_prompt["prompt_token_ids"]
router_generator = await self.router_client.generate(
Tokens(
tokens=engine_prompt["prompt_token_ids"]
LocalBlockHashes(
hashes=compute_block_hash_for_seq_py(
token_ids, self.engine_args.block_size
),
num_tokens=len(token_ids),
).model_dump_json()
)
decision = await router_generator.__anext__()
......
......@@ -29,6 +29,7 @@ Processor:
Router:
min-workers: 1
softmax_sample: true
common-configs: [model, block-size, router]
VllmWorker:
......
......@@ -36,6 +36,11 @@ class Tokens(BaseModel):
tokens: list[int]
class LocalBlockHashes(BaseModel):
    """Request payload carrying per-block hashes of a prompt plus its token count."""

    # One hash per block of the prompt's token ids.
    hashes: list[int]
    # Total number of tokens in the prompt the hashes were computed from.
    num_tokens: int
class PrefillRequest(Request):
request_id: str
......
......@@ -481,6 +481,24 @@ impl KvIndexer {
self.inner.block_size()
}
// Look up overlap scores for a sequence of already-computed block hashes.
//
// The raw `u64` values are wrapped into `LocalBlockHash` and handed to the
// inner indexer; the lookup runs inside a future converted for Python, and
// any error is mapped through `to_pyerr`.
fn find_matches<'p>(&self, py: Python<'p>, sequence: Vec<u64>) -> PyResult<Bound<'p, PyAny>> {
    use llm_rs::kv_router::protocols::LocalBlockHash;

    let inner = self.inner.clone();
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        let hashes: Vec<LocalBlockHash> = sequence.into_iter().map(LocalBlockHash).collect();
        let scores = inner.find_matches(hashes).await.map_err(to_pyerr)?;
        Ok(OverlapScores { inner: scores })
    })
}
fn find_matches_for_request<'p>(
&self,
py: Python<'p>,
......
......@@ -527,6 +527,18 @@ class KvIndexer:
Create a `KvIndexer` object
"""
def find_matches(self, sequence: List[int]) -> OverlapScores:
    """
    Find prefix matches for the given sequence of block hashes.

    NOTE(review): the router calls this with ``await``, so the result
    appears to be awaitable rather than returned directly — confirm
    whether the annotation should reflect that.

    Args:
        sequence: List of block hashes to find matches for

    Returns:
        OverlapScores containing worker matching scores and frequencies
    """
    ...
def find_matches_for_request(
self, token_ids: List[int], lora_id: int
) -> OverlapScores:
......
......@@ -73,7 +73,7 @@ pub struct KvRouterConfig {
impl Default for KvRouterConfig {
fn default() -> Self {
Self {
overlap_score_weight: 2.0,
overlap_score_weight: 1.0,
gpu_cache_usage_weight: 1.0,
waiting_requests_weight: 1.0,
}
......
......@@ -211,9 +211,6 @@ pub fn process_worker_selection(
// Update worker state predictively
// Will be overwritten on next polling of metrics
worker.data.num_requests_waiting += 1;
// Assumes radix attention so KV load is only incremented by uncached blocks
// overlap_blocks can be bigger than required_blocks. I don't know if that's a bug or not.
worker.data.kv_active_blocks += selection
.required_blocks
.saturating_sub(selection.overlap_blocks as u64);
......@@ -230,6 +227,59 @@ pub fn process_worker_selection(
selection.worker_id
}
/// Draws one worker id from `logits` using softmax sampling where lower
/// logits are preferred.
///
/// Each logit is divided by the spread (`max - min`), negated, divided by
/// `temperature`, and fed through a max-shifted softmax. If every logit is
/// identical the distribution is uniform. A single uniform random number is
/// then matched against the cumulative distribution.
///
/// # Panics
///
/// Panics when `logits` is empty.
fn softmax_sample(logits: &HashMap<i64, f64>, temperature: f64) -> i64 {
    if logits.is_empty() {
        panic!("Empty logits for softmax sampling");
    }

    // Split the map into parallel key / value vectors in one pass.
    let (keys, values): (Vec<i64>, Vec<f64>) = logits.iter().map(|(&k, &v)| (k, v)).unzip();

    let min_val = values.iter().copied().fold(f64::INFINITY, f64::min);
    let max_val = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);

    let probabilities: Vec<f64> = if min_val == max_val {
        // No spread between candidates: every worker is equally likely.
        vec![1.0 / keys.len() as f64; keys.len()]
    } else {
        let spread = max_val - min_val;

        // Normalize, negate (lower is better) and apply the temperature.
        let scaled: Vec<f64> = values
            .iter()
            .map(|&v| -(v / spread) / temperature)
            .collect();

        // Shift by the maximum before exponentiating for numerical stability.
        let max_scaled = scaled.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let exp_values: Vec<f64> = scaled.iter().map(|&s| (s - max_scaled).exp()).collect();
        let sum_exp: f64 = exp_values.iter().sum();

        exp_values.iter().map(|&e| e / sum_exp).collect()
    };

    // Walk the cumulative distribution until it covers the random sample.
    let mut rng = rand::rng();
    let sample: f64 = rng.random();

    let mut cumsum = 0.0;
    for (key, prob) in keys.iter().zip(probabilities.iter()) {
        cumsum += prob;
        if sample <= cumsum {
            return *key;
        }
    }

    // Rounding can leave the cumulative sum just below the sample; use the last key.
    keys[keys.len() - 1]
}
// Default implementation matching the Python _cost_function
#[derive(Debug, Clone, Default)]
pub struct DefaultWorkerSelector {
......@@ -257,94 +307,77 @@ impl WorkerSelector for DefaultWorkerSelector {
return Err(KvSchedulerError::NoEndpoints);
}
let mut worker_scores = HashMap::new();
let mut max_waiting = 0.0;
// Calculate worker scores and find max waiting requests
for (worker_id, ep) in workers.endpoints.iter() {
// Calculate score similar to Python version
if let Some(score) = request.overlap.scores.get(worker_id) {
let score = *score as f64 * block_size as f64 / request.isl_tokens as f64;
worker_scores.insert(worker_id, score);
}
// Track max waiting requests
max_waiting = f64::max(max_waiting, ep.data.num_requests_waiting as f64);
}
// make immutable
let worker_scores = worker_scores;
let max_waiting = max_waiting;
let request_blocks = request.isl_tokens.div_ceil(block_size);
let mut worker_logits = HashMap::new();
// Calculate logits for each worker
let mut best_logit = f64::NEG_INFINITY;
let mut best_workers = Vec::new();
for (worker_id, ep) in workers.endpoints.iter() {
let worker_id = *worker_id;
// Get score or default to 0.0
let score = worker_scores.get(&worker_id).copied().unwrap_or(0.0);
// Get overlap blocks for this worker
let overlap_blocks =
request.overlap.scores.get(&worker_id).copied().unwrap_or(0) as f64;
let new_blocks = request_blocks as f64 - overlap_blocks;
let kv_total_blocks = ep.data.kv_total_blocks as f64;
assert!(kv_total_blocks > 0.0);
// Calculate normalized metrics
let gpu_cache_usage = ep.data.gpu_cache_usage_perc as f64;
let normalized_waiting = if max_waiting > 0.0 {
ep.data.num_requests_waiting as f64 / max_waiting
} else {
0.0
};
let normalized_new_blocks = new_blocks / kv_total_blocks;
let gpu_cache_usage = (ep.data.kv_active_blocks as f64) / kv_total_blocks;
let num_requests_waiting = ep.data.num_requests_waiting as f64;
// Calculate logit using same formula as Python
let logit = self.kv_router_config.overlap_score_weight * score
- self.kv_router_config.gpu_cache_usage_weight * gpu_cache_usage
- self.kv_router_config.waiting_requests_weight * normalized_waiting;
// Calculate logit (lower is better)
let logit = self.kv_router_config.overlap_score_weight * normalized_new_blocks
+ self.kv_router_config.gpu_cache_usage_weight * gpu_cache_usage
+ self.kv_router_config.waiting_requests_weight * num_requests_waiting;
tracing::trace!(
"Formula for {worker_id}: {logit:.3} = {:.1} * {score:.3} - {:.1} * {gpu_cache_usage:.3} - {:.1} * {normalized_waiting:.3}",
worker_logits.insert(worker_id, logit);
tracing::info!(
"Formula for {worker_id}: {logit:.3} = {:.1} * {normalized_new_blocks:.3} + {:.1} * {gpu_cache_usage:.3} + {:.1} * {num_requests_waiting:.3}",
self.kv_router_config.overlap_score_weight,
self.kv_router_config.gpu_cache_usage_weight,
self.kv_router_config.waiting_requests_weight,
);
// Track best workers
match logit.partial_cmp(&best_logit) {
Some(std::cmp::Ordering::Greater) => {
best_logit = logit;
best_workers.clear();
best_workers.push(worker_id);
}
Some(std::cmp::Ordering::Equal) => {
best_workers.push(worker_id);
}
_ => {}
}
}
// Return early if no valid workers found
if best_workers.is_empty() {
return Err(KvSchedulerError::NoEndpoints);
} else if best_logit == 0.0 {
tracing::debug!("best worker logit is 0");
}
let worker_id = if best_workers.len() == 1 {
best_workers[0]
} else {
// Randomly select from best workers
if worker_logits.is_empty() || worker_logits.values().all(|&v| v == 0.0) {
tracing::warn!("All worker logits are zero. Fallback to random routing.");
// Pick random worker
let mut rng = rand::rng();
best_workers[rng.random_range(0..best_workers.len())]
};
// Lower to trace level eventually. Nice to see KV routing working for now.
tracing::debug!("Selected worker: {worker_id}, logit: {best_logit:.3}");
let worker_ids: Vec<_> = workers.endpoints.keys().copied().collect();
let worker_id = worker_ids[rng.random_range(0..worker_ids.len())];
let overlap_blocks =
request.overlap.scores.get(&worker_id).copied().unwrap_or(0) as usize;
return Ok(WorkerSelectionResult {
worker_id,
required_blocks: request_blocks as u64,
overlap_blocks,
});
}
// Log selection metrics
let total_blocks = std::cmp::max(request.isl_tokens / block_size, 1) as u64;
let overlap_blocks = request.overlap.scores.get(&worker_id).copied().unwrap_or(0) as usize;
// Use softmax sampling to select worker
let temperature = 1.0; // You can make this configurable if needed
let best_worker_id = softmax_sample(&worker_logits, temperature);
let overlap_blocks = request
.overlap
.scores
.get(&best_worker_id)
.copied()
.unwrap_or(0) as usize;
let best_logit = worker_logits[&best_worker_id];
tracing::info!(
"Selected worker: {}, logit: {:.3}",
best_worker_id,
best_logit
);
Ok(WorkerSelectionResult {
worker_id,
required_blocks: total_blocks,
worker_id: best_worker_id,
required_blocks: request_blocks as u64,
overlap_blocks,
})
}
......@@ -354,6 +387,33 @@ impl WorkerSelector for DefaultWorkerSelector {
mod tests {
use super::*;
#[test]
fn test_softmax_sample_single_key() {
    // A map with exactly one entry must always yield that entry, whatever
    // the temperature or the logit value.
    let worker_id = 42;

    // Sweep the temperature with a fixed, arbitrary logit.
    let fixed_logit = HashMap::from([(worker_id, 0.5)]);
    for &temperature in &[0.1, 1.0, 10.0] {
        let result = softmax_sample(&fixed_logit, temperature);
        assert_eq!(result, worker_id, "Should return the only available worker");
    }

    // Sweep the logit: very negative, very positive, and zero.
    for &logit in &[-100.0, 100.0, 0.0] {
        let single = HashMap::from([(worker_id, logit)]);
        assert_eq!(softmax_sample(&single, 1.0), worker_id);
    }
}
// Helper to create a worker endpoint
fn create_endpoint(
worker_id: i64,
......@@ -412,51 +472,6 @@ mod tests {
}
}
#[test]
fn test_select_worker_basic() {
// Setup workers
let workers = create_workers(vec![
WorkerInfo {
id: 1,
usage: 0.50,
waiting: 1,
},
WorkerInfo {
id: 2,
usage: 0.80,
waiting: 0,
},
]);
// Setup request: 100 tokens, block_size=20 (5 blocks)
let request = create_request(
vec![
WorkerOverlap {
worker_id: 1,
overlap_blocks: 3,
},
WorkerOverlap {
worker_id: 2,
overlap_blocks: 4,
},
],
100,
);
let selector = DefaultWorkerSelector::new(None);
let block_size = 20;
// Execute selection
let result = selector
.select_worker(&workers, &request, block_size)
.expect("Should select a worker");
// Worker 2 should win because:
// Worker1: 2.0 * 0.600 - 1.0 * 0.500 - 1.0 * 1.000 = -0.3
// Worker2: 2.0 * 0.800 - 1.0 * 0.800 - 1.0 * 0.000 = 0.8
assert_eq!(result.worker_id, 2);
assert_eq!(result.required_blocks, 5); // 100 tokens / 20 block_size
assert_eq!(result.overlap_blocks, 4);
}
#[test]
fn test_no_endpoints() {
let workers = create_workers(vec![]);
......@@ -470,69 +485,114 @@ mod tests {
}
}
#[test]
fn test_no_overlap_scores() {
// Workers exist but request has no overlap scores
let workers = create_workers(vec![WorkerInfo {
id: 1,
usage: 0.50,
waiting: 1,
}]);
let request = create_request(vec![], 100); // No overlaps
let selector = DefaultWorkerSelector::new(None);
let block_size = 20;
let result = selector
.select_worker(&workers, &request, block_size)
.expect("Should fallback to selecting worker");
// Worker1 should be selected with 0 overlap
assert_eq!(result.worker_id, 1);
assert_eq!(result.overlap_blocks, 0);
}
#[test]
fn test_custom_weights() {
// Setup workers
let workers = create_workers(vec![
WorkerInfo {
id: 1,
usage: 0.50,
waiting: 1,
},
WorkerInfo {
id: 2,
usage: 0.80,
waiting: 0,
},
]);
// Custom config with high priority on GPU usage
let config = KvRouterConfig {
gpu_cache_usage_weight: 10.0, // Very high weight
overlap_score_weight: 2.0, // just current defaults
waiting_requests_weight: 1.0,
};
let selector = DefaultWorkerSelector::new(Some(config));
let request = create_request(
vec![
WorkerOverlap {
worker_id: 1,
overlap_blocks: 3,
},
WorkerOverlap {
worker_id: 2,
overlap_blocks: 4,
},
],
100,
);
let block_size = 20;
let result = selector
.select_worker(&workers, &request, block_size)
.expect("Should select worker");
assert_eq!(result.worker_id, 1);
}
// #[test]
// fn test_select_worker_basic() {
// // Setup workers
// let workers = create_workers(vec![
// WorkerInfo {
// id: 1,
// usage: 0.50,
// waiting: 1,
// },
// WorkerInfo {
// id: 2,
// usage: 0.80,
// waiting: 0,
// },
// ]);
// // Setup request: 100 tokens, block_size=20 (5 blocks)
// let request = create_request(
// vec![
// WorkerOverlap {
// worker_id: 1,
// overlap_blocks: 3,
// },
// WorkerOverlap {
// worker_id: 2,
// overlap_blocks: 4,
// },
// ],
// 100,
// );
// let selector = DefaultWorkerSelector::new(None);
// let block_size = 20;
// // Execute selection
// let result = selector
// .select_worker(&workers, &request, block_size)
// .expect("Should select a worker");
// // Worker 2 should win because:
// // Worker1: 2.0 * 0.600 - 1.0 * 0.500 - 1.0 * 1.000 = -0.3
// // Worker2: 2.0 * 0.800 - 1.0 * 0.800 - 1.0 * 0.000 = 0.8
// assert_eq!(result.worker_id, 2);
// assert_eq!(result.required_blocks, 5); // 100 tokens / 20 block_size
// assert_eq!(result.overlap_blocks, 4);
// }
// #[test]
// fn test_no_overlap_scores() {
// // Workers exist but request has no overlap scores
// let workers = create_workers(vec![WorkerInfo {
// id: 1,
// usage: 0.50,
// waiting: 1,
// }]);
// let request = create_request(vec![], 100); // No overlaps
// let selector = DefaultWorkerSelector::new(None);
// let block_size = 20;
// let result = selector
// .select_worker(&workers, &request, block_size)
// .expect("Should fallback to selecting worker");
// // Worker1 should be selected with 0 overlap
// assert_eq!(result.worker_id, 1);
// assert_eq!(result.overlap_blocks, 0);
// }
// #[test]
// fn test_custom_weights() {
// // Setup workers
// let workers = create_workers(vec![
// WorkerInfo {
// id: 1,
// usage: 0.50,
// waiting: 1,
// },
// WorkerInfo {
// id: 2,
// usage: 0.80,
// waiting: 0,
// },
// ]);
// // Custom config with high priority on GPU usage
// let config = KvRouterConfig {
// gpu_cache_usage_weight: 10.0, // Very high weight
// overlap_score_weight: 2.0, // just current defaults
// waiting_requests_weight: 1.0,
// };
// let selector = DefaultWorkerSelector::new(Some(config));
// let request = create_request(
// vec![
// WorkerOverlap {
// worker_id: 1,
// overlap_blocks: 3,
// },
// WorkerOverlap {
// worker_id: 2,
// overlap_blocks: 4,
// },
// ],
// 100,
// );
// let block_size = 20;
// let result = selector
// .select_worker(&workers, &request, block_size)
// .expect("Should select worker");
// assert_eq!(result.worker_id, 1);
// }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment