Commit a544d823 authored by Yan Ru Pei, committed by GitHub
Browse files

chore: more Pythonic kv router cleanups in examples (#396)

parent cce0c028
......@@ -83,6 +83,12 @@ class Router:
vllm_logger.info("Initializing Custom Router")
self.args = parse_args(self.__class__.__name__, "")
self.default_metrics = {
"gpu_cache_usage_perc": 0.0,
"num_requests_waiting": 0.0,
"gpu_prefix_cache_hit_rate": 0.0,
}
@async_on_start
async def async_init(self):
self.runtime = dynamo_context["runtime"]
......@@ -140,21 +146,13 @@ class Router:
)
worker_metrics = {}
# pull metrics for each worker
max_waiting = 0.0
if metrics:
for endpoint in metrics.endpoints:
worker_id = endpoint.worker_id
worker_metrics[worker_id] = {
"gpu_cache_usage_perc": getattr(
endpoint, "gpu_cache_usage_perc", 0.0
),
"num_requests_waiting": getattr(
endpoint, "num_requests_waiting", 0.0
),
"gpu_prefix_cache_hit_rate": getattr(
endpoint, "gpu_prefix_cache_hit_rate", 0.0
),
key: getattr(endpoint, key, self.default_metrics[key])
for key in self.default_metrics.keys()
}
max_waiting = max(
max_waiting, worker_metrics[worker_id]["num_requests_waiting"]
......@@ -168,14 +166,8 @@ class Router:
for worker_id in worker_ids:
# Use default values if worker not in scores or metrics
score = worker_scores.get(worker_id, 0.0)
metrics_dict = worker_metrics.get(
worker_id,
{
"gpu_cache_usage_perc": 0.0,
"num_requests_waiting": 0.0,
"gpu_prefix_cache_hit_rate": 0.0,
},
)
metrics_dict = worker_metrics.get(worker_id, self.default_metrics)
gpu_cache_usage = metrics_dict["gpu_cache_usage_perc"]
normalized_waiting = (
metrics_dict["num_requests_waiting"] / max_waiting
......@@ -185,15 +177,13 @@ class Router:
# Have 1 metric that weights towards cache hit
# 2 metrics that penalize overloaded worker and queuing
worker_logits[worker_id] = (
2 * score - metrics_dict["gpu_cache_usage_perc"] - normalized_waiting
)
worker_logits[worker_id] = 2 * score - gpu_cache_usage - normalized_waiting
vllm_logger.info(
f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = 2.0 * {score:.3f} - {metrics_dict['gpu_cache_usage_perc']:.3f} - {normalized_waiting:.3f}"
f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = 2.0 * {score:.3f} - {gpu_cache_usage:.3f} - {normalized_waiting:.3f}"
)
if not worker_logits or all(logit == 0 for logit in worker_logits.values()):
return ""
return "", 0.0
# Select the worker with the highest logit
max_logit = max(worker_logits.values())
......@@ -204,30 +194,26 @@ class Router:
# Log the metrics for the selected worker
if best_worker_id:
vllm_logger.info(
f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}"
)
vllm_logger.info(
f"Score: {scores.scores.get(best_worker_id, 0.0) if scores else 0.0:.3f}"
)
metrics_dict = worker_metrics.get(best_worker_id, self.default_metrics)
# Create log messages
log_messages = [
f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}",
f"Score: {scores.scores.get(best_worker_id, 0.0) if scores else 0.0:.3f}",
f"GPU Cache Hit Rate: {metrics_dict['gpu_prefix_cache_hit_rate']:.3f}",
f"GPU Cache Usage: {metrics_dict['gpu_cache_usage_perc']:.3f}",
f"Requests Waiting: {metrics_dict['num_requests_waiting']}",
]
metrics_dict = worker_metrics.get(best_worker_id, {})
vllm_logger.info(
f"GPU Cache Hit Rate: {metrics_dict.get('gpu_prefix_cache_hit_rate', 0.0):.3f}"
)
vllm_logger.info(
f"GPU Cache Usage: {metrics_dict.get('gpu_cache_usage_perc', 0.0):.3f}"
)
vllm_logger.info(
f"Requests Waiting: {metrics_dict.get('num_requests_waiting', 0.0) / max_waiting if max_waiting > 0 else 0.0:.3f}"
)
# Log to vllm_logger
for message in log_messages:
vllm_logger.info(message)
return best_worker_id, worker_scores.get(best_worker_id, 0.0)
@dynamo_endpoint()
async def generate(self, request: Tokens) -> AsyncIterator[WorkerId]:
lora_id = 0
worker_id = ""
try:
scores = await self.indexer.find_matches_for_request(
request.tokens, lora_id
......@@ -236,17 +222,12 @@ class Router:
scores = {}
vllm_logger.exception(f"Error finding matches: {e}")
token_length = len(request.tokens)
metrics = await self.metrics_aggregator.get_metrics()
schedule_result = self._cost_function(scores, metrics, token_length)
if schedule_result == "":
worker_id = ""
prefix_hit_rate = 0.0
else:
worker_id, prefix_hit_rate = schedule_result
worker_id, prefix_hit_rate = self._cost_function(
scores, metrics, len(request.tokens)
)
vllm_logger.info(
f"Scheduling to worker_id: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
)
yield f"{worker_id}_{prefix_hit_rate}"
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment