Commit 43ff14ce authored by Yan Ru Pei, committed by GitHub
Browse files

chore: KV router Pythonic cleanups (#324)

parent bb35f36f
......@@ -113,6 +113,23 @@ class Router:
metrics: AggregatedMetrics | None,
token_length: int,
):
"""The cost function for deciding the best worker to route a request to.
If there are multiple workers sharing the same optimal cost, then
one of them is randomly selected.
Args:
scores (OverlapScores | None): The number of matching blocks between
the request and the prefix cache of each worker.
metrics (AggregatedMetrics | None): Several worker metrics polled
by the `KvMetricsAggregator`, currently including the
GPU cache usage, number of waiting requests, and the
GPU prefix cache hit rate.
token_length (int): The number of tokens in the request.
Returns:
(str, float): The best worker id and the corresponding score.
"""
worker_scores = {}
if scores:
for worker_id, score in scores.scores.items():
......@@ -129,15 +146,15 @@ class Router:
for endpoint in metrics.endpoints:
worker_id = endpoint.worker_id
worker_metrics[worker_id] = {
"gpu_cache_usage_perc": endpoint.gpu_cache_usage_perc
if hasattr(endpoint, "gpu_cache_usage_perc")
else 0.0,
"num_requests_waiting": endpoint.num_requests_waiting
if hasattr(endpoint, "num_requests_waiting")
else 0.0,
"gpu_prefix_cache_hit_rate": endpoint.gpu_prefix_cache_hit_rate
if hasattr(endpoint, "gpu_prefix_cache_hit_rate")
else 0.0,
"gpu_cache_usage_perc": getattr(
endpoint, "gpu_cache_usage_perc", 0.0
),
"num_requests_waiting": getattr(
endpoint, "num_requests_waiting", 0.0
),
"gpu_prefix_cache_hit_rate": getattr(
endpoint, "gpu_prefix_cache_hit_rate", 0.0
),
}
max_waiting = max(
max_waiting, worker_metrics[worker_id]["num_requests_waiting"]
......@@ -179,14 +196,11 @@ class Router:
return ""
# Select the worker with the highest logit
if worker_logits:
max_logit = max(worker_logits.values())
best_workers = [
wid for wid, logit in worker_logits.items() if logit == max_logit
]
best_worker_id = random.choice(best_workers)
else:
best_worker_id = ""
max_logit = max(worker_logits.values())
best_workers = [
wid for wid, logit in worker_logits.items() if logit == max_logit
]
best_worker_id = random.choice(best_workers)
# Log the metrics for the selected worker
if best_worker_id:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment