Commit a544d823, authored by Yan Ru Pei, committed by GitHub
Browse files

chore: more Pythonic kv router cleanups in examples (#396)

parent cce0c028
...@@ -83,6 +83,12 @@ class Router: ...@@ -83,6 +83,12 @@ class Router:
vllm_logger.info("Initializing Custom Router") vllm_logger.info("Initializing Custom Router")
self.args = parse_args(self.__class__.__name__, "") self.args = parse_args(self.__class__.__name__, "")
self.default_metrics = {
"gpu_cache_usage_perc": 0.0,
"num_requests_waiting": 0.0,
"gpu_prefix_cache_hit_rate": 0.0,
}
@async_on_start @async_on_start
async def async_init(self): async def async_init(self):
self.runtime = dynamo_context["runtime"] self.runtime = dynamo_context["runtime"]
...@@ -140,21 +146,13 @@ class Router: ...@@ -140,21 +146,13 @@ class Router:
) )
worker_metrics = {} worker_metrics = {}
# pull metrics for each worker
max_waiting = 0.0 max_waiting = 0.0
if metrics: if metrics:
for endpoint in metrics.endpoints: for endpoint in metrics.endpoints:
worker_id = endpoint.worker_id worker_id = endpoint.worker_id
worker_metrics[worker_id] = { worker_metrics[worker_id] = {
"gpu_cache_usage_perc": getattr( key: getattr(endpoint, key, self.default_metrics[key])
endpoint, "gpu_cache_usage_perc", 0.0 for key in self.default_metrics.keys()
),
"num_requests_waiting": getattr(
endpoint, "num_requests_waiting", 0.0
),
"gpu_prefix_cache_hit_rate": getattr(
endpoint, "gpu_prefix_cache_hit_rate", 0.0
),
} }
max_waiting = max( max_waiting = max(
max_waiting, worker_metrics[worker_id]["num_requests_waiting"] max_waiting, worker_metrics[worker_id]["num_requests_waiting"]
...@@ -168,14 +166,8 @@ class Router: ...@@ -168,14 +166,8 @@ class Router:
for worker_id in worker_ids: for worker_id in worker_ids:
# Use default values if worker not in scores or metrics # Use default values if worker not in scores or metrics
score = worker_scores.get(worker_id, 0.0) score = worker_scores.get(worker_id, 0.0)
metrics_dict = worker_metrics.get( metrics_dict = worker_metrics.get(worker_id, self.default_metrics)
worker_id, gpu_cache_usage = metrics_dict["gpu_cache_usage_perc"]
{
"gpu_cache_usage_perc": 0.0,
"num_requests_waiting": 0.0,
"gpu_prefix_cache_hit_rate": 0.0,
},
)
normalized_waiting = ( normalized_waiting = (
metrics_dict["num_requests_waiting"] / max_waiting metrics_dict["num_requests_waiting"] / max_waiting
...@@ -185,15 +177,13 @@ class Router: ...@@ -185,15 +177,13 @@ class Router:
# Have 1 metric that weights towards cache hit # Have 1 metric that weights towards cache hit
# 2 metrics that penalize overloaded worker and queuing # 2 metrics that penalize overloaded worker and queuing
worker_logits[worker_id] = ( worker_logits[worker_id] = 2 * score - gpu_cache_usage - normalized_waiting
2 * score - metrics_dict["gpu_cache_usage_perc"] - normalized_waiting
)
vllm_logger.info( vllm_logger.info(
f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = 2.0 * {score:.3f} - {metrics_dict['gpu_cache_usage_perc']:.3f} - {normalized_waiting:.3f}" f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = 2.0 * {score:.3f} - {gpu_cache_usage:.3f} - {normalized_waiting:.3f}"
) )
if not worker_logits or all(logit == 0 for logit in worker_logits.values()): if not worker_logits or all(logit == 0 for logit in worker_logits.values()):
return "" return "", 0.0
# Select the worker with the highest logit # Select the worker with the highest logit
max_logit = max(worker_logits.values()) max_logit = max(worker_logits.values())
...@@ -204,30 +194,26 @@ class Router: ...@@ -204,30 +194,26 @@ class Router:
# Log the metrics for the selected worker # Log the metrics for the selected worker
if best_worker_id: if best_worker_id:
vllm_logger.info( metrics_dict = worker_metrics.get(best_worker_id, self.default_metrics)
f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}"
)
vllm_logger.info(
f"Score: {scores.scores.get(best_worker_id, 0.0) if scores else 0.0:.3f}"
)
metrics_dict = worker_metrics.get(best_worker_id, {}) # Create log messages
vllm_logger.info( log_messages = [
f"GPU Cache Hit Rate: {metrics_dict.get('gpu_prefix_cache_hit_rate', 0.0):.3f}" f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}",
) f"Score: {scores.scores.get(best_worker_id, 0.0) if scores else 0.0:.3f}",
vllm_logger.info( f"GPU Cache Hit Rate: {metrics_dict['gpu_prefix_cache_hit_rate']:.3f}",
f"GPU Cache Usage: {metrics_dict.get('gpu_cache_usage_perc', 0.0):.3f}" f"GPU Cache Usage: {metrics_dict['gpu_cache_usage_perc']:.3f}",
) f"Requests Waiting: {metrics_dict['num_requests_waiting']}",
vllm_logger.info( ]
f"Requests Waiting: {metrics_dict.get('num_requests_waiting', 0.0) / max_waiting if max_waiting > 0 else 0.0:.3f}"
) # Log to vllm_logger
for message in log_messages:
vllm_logger.info(message)
return best_worker_id, worker_scores.get(best_worker_id, 0.0) return best_worker_id, worker_scores.get(best_worker_id, 0.0)
@dynamo_endpoint() @dynamo_endpoint()
async def generate(self, request: Tokens) -> AsyncIterator[WorkerId]: async def generate(self, request: Tokens) -> AsyncIterator[WorkerId]:
lora_id = 0 lora_id = 0
worker_id = ""
try: try:
scores = await self.indexer.find_matches_for_request( scores = await self.indexer.find_matches_for_request(
request.tokens, lora_id request.tokens, lora_id
...@@ -236,17 +222,12 @@ class Router: ...@@ -236,17 +222,12 @@ class Router:
scores = {} scores = {}
vllm_logger.exception(f"Error finding matches: {e}") vllm_logger.exception(f"Error finding matches: {e}")
token_length = len(request.tokens)
metrics = await self.metrics_aggregator.get_metrics() metrics = await self.metrics_aggregator.get_metrics()
schedule_result = self._cost_function(scores, metrics, token_length) worker_id, prefix_hit_rate = self._cost_function(
if schedule_result == "": scores, metrics, len(request.tokens)
worker_id = "" )
prefix_hit_rate = 0.0
else:
worker_id, prefix_hit_rate = schedule_result
vllm_logger.info( vllm_logger.info(
f"Scheduling to worker_id: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}" f"Scheduling to worker_id: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
) )
yield f"{worker_id}_{prefix_hit_rate}" yield f"{worker_id}_{prefix_hit_rate}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment