Commit 3f84cdad authored by Alec's avatar Alec Committed by GitHub
Browse files

feat: add new metrics and simple router cost fn (#88)

parent 2153ee81
...@@ -112,11 +112,17 @@ fn mock_stats_handler(_stats: Stats) -> serde_json::Value { ...@@ -112,11 +112,17 @@ fn mock_stats_handler(_stats: Stats) -> serde_json::Value {
let request_active_slots = rand::thread_rng().gen_range(0..=request_total_slots); let request_active_slots = rand::thread_rng().gen_range(0..=request_total_slots);
let kv_total_blocks = 100; let kv_total_blocks = 100;
let kv_active_blocks = rand::thread_rng().gen_range(0..=kv_total_blocks); let kv_active_blocks = rand::thread_rng().gen_range(0..=kv_total_blocks);
let num_requests_waiting = rand::thread_rng().gen_range(0..=100);
let gpu_cache_usage_perc = rand::thread_rng().gen_range(0.0..=1.0);
let gpu_prefix_cache_hit_rate = rand::thread_rng().gen_range(0.0..=1.0);
let stats = ForwardPassMetrics { let stats = ForwardPassMetrics {
request_active_slots, request_active_slots,
request_total_slots, request_total_slots,
kv_active_blocks, kv_active_blocks,
kv_total_blocks, kv_total_blocks,
num_requests_waiting,
gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate,
}; };
println!("stats out: {:?}", stats); println!("stats out: {:?}", stats);
serde_json::to_value(stats).unwrap() serde_json::to_value(stats).unwrap()
......
...@@ -2990,7 +2990,7 @@ index 3cf1850e..ae006579 100644 ...@@ -2990,7 +2990,7 @@ index 3cf1850e..ae006579 100644
+ gpu_cache_usage_perc: float + gpu_cache_usage_perc: float
+ gpu_prefix_cache_hit_rate: float + gpu_prefix_cache_hit_rate: float
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index 85b5f31e..fe719642 100644 index 85b5f31e..05030292 100644
--- a/vllm/engine/multiprocessing/client.py --- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py
@@ -8,6 +8,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, @@ -8,6 +8,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
...@@ -3084,7 +3084,7 @@ index 85b5f31e..fe719642 100644 ...@@ -3084,7 +3084,7 @@ index 85b5f31e..fe719642 100644
@staticmethod @staticmethod
def is_unsupported_config(engine_args: AsyncEngineArgs): def is_unsupported_config(engine_args: AsyncEngineArgs):
# Pipeline parallel not yet supported # Pipeline parallel not yet supported
@@ -180,6 +210,63 @@ class MQLLMEngineClient(EngineClient): @@ -180,6 +210,61 @@ class MQLLMEngineClient(EngineClient):
except Exception as e: except Exception as e:
self._set_errored(e) self._set_errored(e)
...@@ -3118,21 +3118,19 @@ index 85b5f31e..fe719642 100644 ...@@ -3118,21 +3118,19 @@ index 85b5f31e..fe719642 100644
+ # Metrics received- check the message + # Metrics received- check the message
+ message: Frame = await self.metrics_socket.recv(copy=False) + message: Frame = await self.metrics_socket.recv(copy=False)
+ metrics = pickle.loads(message.buffer) + metrics = pickle.loads(message.buffer)
+ if self.metrics_publisher is not None: + if self.metrics_publisher is not None and isinstance(
+ if isinstance(metrics, KvMetrics): + metrics, KvMetrics
+ self.metrics_publisher.publish(metrics.request_active_slots, + ):
+ metrics.request_total_slots, + self.metrics_publisher.publish(metrics.request_active_slots,
+ metrics.kv_active_blocks, + metrics.request_total_slots,
+ metrics.kv_total_blocks, + metrics.kv_active_blocks,
+ metrics.num_requests_waiting, + metrics.kv_total_blocks,
+ metrics.gpu_cache_usage_perc, + metrics.num_requests_waiting,
+ metrics.gpu_prefix_cache_hit_rate) + metrics.gpu_cache_usage_perc,
+ if isinstance(metrics, Stats): + metrics.gpu_prefix_cache_hit_rate)
+ # TODO + logger.debug("Metrics successful.")
+ # Send the whole stats to user +
+ pass + # TODO: Investigate sending whole stats object
+
+ logger.debug("Metrics successful.")
+ +
+ except asyncio.CancelledError: + except asyncio.CancelledError:
+ logger.debug("Shutting down MQLLMEngineClient check metrics loop.") + logger.debug("Shutting down MQLLMEngineClient check metrics loop.")
...@@ -3148,7 +3146,7 @@ index 85b5f31e..fe719642 100644 ...@@ -3148,7 +3146,7 @@ index 85b5f31e..fe719642 100644
async def run_output_handler_loop(self): async def run_output_handler_loop(self):
"""Get RequestOutputs from Engine and stream to Request Queues""" """Get RequestOutputs from Engine and stream to Request Queues"""
@@ -278,12 +365,26 @@ class MQLLMEngineClient(EngineClient): @@ -278,12 +363,26 @@ class MQLLMEngineClient(EngineClient):
# Wait until server is ready. # Wait until server is ready.
response = await self._wait_for_server_rpc(socket) response = await self._wait_for_server_rpc(socket)
...@@ -3175,7 +3173,7 @@ index 85b5f31e..fe719642 100644 ...@@ -3175,7 +3173,7 @@ index 85b5f31e..fe719642 100644
def close(self): def close(self):
"""Destroy the ZeroMQ Context.""" """Destroy the ZeroMQ Context."""
@@ -293,6 +394,8 @@ class MQLLMEngineClient(EngineClient): @@ -293,6 +392,8 @@ class MQLLMEngineClient(EngineClient):
# Cancel background tasks. # Cancel background tasks.
if self.health_loop is not None: if self.health_loop is not None:
self.health_loop.cancel() self.health_loop.cancel()
...@@ -3184,7 +3182,7 @@ index 85b5f31e..fe719642 100644 ...@@ -3184,7 +3182,7 @@ index 85b5f31e..fe719642 100644
if self.output_loop is not None: if self.output_loop is not None:
self.output_loop.cancel() self.output_loop.cancel()
@@ -415,6 +518,9 @@ class MQLLMEngineClient(EngineClient): @@ -415,6 +516,9 @@ class MQLLMEngineClient(EngineClient):
""" """
if self._errored_with is not None: if self._errored_with is not None:
raise self._errored_with raise self._errored_with
...@@ -3194,7 +3192,7 @@ index 85b5f31e..fe719642 100644 ...@@ -3194,7 +3192,7 @@ index 85b5f31e..fe719642 100644
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
@@ -473,6 +579,7 @@ class MQLLMEngineClient(EngineClient): @@ -473,6 +577,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
...@@ -3202,7 +3200,7 @@ index 85b5f31e..fe719642 100644 ...@@ -3202,7 +3200,7 @@ index 85b5f31e..fe719642 100644
*, *,
inputs: Optional[PromptType] = None # DEPRECATED inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
@@ -502,7 +609,8 @@ class MQLLMEngineClient(EngineClient): @@ -502,7 +607,8 @@ class MQLLMEngineClient(EngineClient):
return self._process_request(prompt, sampling_params, request_id, return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers, lora_request, trace_headers,
...@@ -3212,7 +3210,7 @@ index 85b5f31e..fe719642 100644 ...@@ -3212,7 +3210,7 @@ index 85b5f31e..fe719642 100644
@overload @overload
def encode( def encode(
@@ -586,6 +694,7 @@ class MQLLMEngineClient(EngineClient): @@ -586,6 +692,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
...@@ -3220,7 +3218,7 @@ index 85b5f31e..fe719642 100644 ...@@ -3220,7 +3218,7 @@ index 85b5f31e..fe719642 100644
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
PoolingRequestOutput, None]]: PoolingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses.""" """Send an RPCGenerateRequest to the RPCServer and stream responses."""
@@ -630,6 +739,12 @@ class MQLLMEngineClient(EngineClient): @@ -630,6 +737,12 @@ class MQLLMEngineClient(EngineClient):
else: else:
lp_bytes = None lp_bytes = None
...@@ -3233,7 +3231,7 @@ index 85b5f31e..fe719642 100644 ...@@ -3233,7 +3231,7 @@ index 85b5f31e..fe719642 100644
request_bytes = pickle.dumps( request_bytes = pickle.dumps(
RPCProcessRequest( RPCProcessRequest(
prompt=prompt, prompt=prompt,
@@ -639,11 +754,11 @@ class MQLLMEngineClient(EngineClient): @@ -639,11 +752,11 @@ class MQLLMEngineClient(EngineClient):
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
priority=priority, priority=priority,
...@@ -3247,7 +3245,7 @@ index 85b5f31e..fe719642 100644 ...@@ -3247,7 +3245,7 @@ index 85b5f31e..fe719642 100644
await self.input_socket.send_multipart(parts, copy=False) await self.input_socket.send_multipart(parts, copy=False)
# 4) Stream the RequestOutputs from the output queue. Note # 4) Stream the RequestOutputs from the output queue. Note
@@ -705,3 +820,6 @@ class MQLLMEngineClient(EngineClient): @@ -705,3 +818,6 @@ class MQLLMEngineClient(EngineClient):
# Raise on error, otherwise happily return None # Raise on error, otherwise happily return None
if isinstance(request_output, BaseException): if isinstance(request_output, BaseException):
raise request_output raise request_output
...@@ -3255,7 +3253,7 @@ index 85b5f31e..fe719642 100644 ...@@ -3255,7 +3253,7 @@ index 85b5f31e..fe719642 100644
+ def set_metrics_publisher(self, metrics_publisher): + def set_metrics_publisher(self, metrics_publisher):
+ self.metrics_publisher = metrics_publisher + self.metrics_publisher = metrics_publisher
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
index a0dd7958..3204cfb8 100644 index a0dd7958..c82bc15b 100644
--- a/vllm/engine/multiprocessing/engine.py --- a/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py
@@ -3,35 +3,115 @@ @@ -3,35 +3,115 @@
...@@ -3355,23 +3353,23 @@ index a0dd7958..3204cfb8 100644 ...@@ -3355,23 +3353,23 @@ index a0dd7958..3204cfb8 100644
+ self.metrics_socket.send_multipart((metrics_bytes, ), copy=False) + self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
+ +
+# TODO: Send entire stats object to the client +# TODO: Send entire stats object to the client
+class StatLogger(StatLoggerBase): +# class StatLogger(StatLoggerBase):
+ def __init__( +# def __init__(
+ self, +# self,
+ metrics_socket +# metrics_socket
+ ): +# ):
+ self.metrics_socket = metrics_socket +# self.metrics_socket = metrics_socket
+ +
+ def log(self, stats: Stats) -> None: +# def log(self, stats: Stats) -> None:
+ self._send_metrics(stats) +# self._send_metrics(stats)
+ +
+ def info(self, type: str, obj: SupportsMetricsInfo) -> None: +# def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+ pass +# pass
+ +
+ def _send_metrics(self, stats: Stats): +# def _send_metrics(self, stats: Stats):
+ if not self.metrics_socket.closed: +# if not self.metrics_socket.closed:
+ metrics_bytes = pickle.dumps(stats) +# metrics_bytes = pickle.dumps(stats)
+ self.metrics_socket.send_multipart((metrics_bytes, ), copy=False) +# self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
+ +
+ +
+ +
...@@ -3379,7 +3377,7 @@ index a0dd7958..3204cfb8 100644 ...@@ -3379,7 +3377,7 @@ index a0dd7958..3204cfb8 100644
class MQLLMEngine: class MQLLMEngine:
"""A multiprocessing wrapper for :class:`LLMEngine`. """A multiprocessing wrapper for :class:`LLMEngine`.
@@ -94,12 +174,35 @@ class MQLLMEngine: @@ -94,12 +174,37 @@ class MQLLMEngine:
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
...@@ -3406,16 +3404,18 @@ index a0dd7958..3204cfb8 100644 ...@@ -3406,16 +3404,18 @@ index a0dd7958..3204cfb8 100644
+ self.engine.cache_config.num_gpu_blocks, + self.engine.cache_config.num_gpu_blocks,
+ self.metrics_socket + self.metrics_socket
+ ) + )
+ self.general_stat_logger = StatLogger(
+ self.metrics_socket
+ )
+ self.engine.add_logger("kv_metrics", self.kv_stat_logger) + self.engine.add_logger("kv_metrics", self.kv_stat_logger)
+ self.engine.add_logger("general_metrics", self.general_stat_logger) +
+ # TODO investigate sending whole stats object
+ # self.general_stat_logger = StatLogger(
+ # self.metrics_socket
+ # )
+ # self.engine.add_logger("general_metrics", self.general_stat_logger)
+ +
@property @property
def dead_error(self) -> BaseException: def dead_error(self) -> BaseException:
if self._errored_with is not None: if self._errored_with is not None:
@@ -171,8 +274,17 @@ class MQLLMEngine: @@ -171,8 +276,17 @@ class MQLLMEngine:
# Handle the query from the Client. # Handle the query from the Client.
if request == RPCStartupRequest.IS_SERVER_READY: if request == RPCStartupRequest.IS_SERVER_READY:
tracing_enabled = self.engine.is_tracing_enabled() tracing_enabled = self.engine.is_tracing_enabled()
...@@ -3435,7 +3435,7 @@ index a0dd7958..3204cfb8 100644 ...@@ -3435,7 +3435,7 @@ index a0dd7958..3204cfb8 100644
except Exception as e: except Exception as e:
response = e response = e
@@ -185,6 +297,7 @@ class MQLLMEngine: @@ -185,6 +299,7 @@ class MQLLMEngine:
while True: while True:
if not self.engine.has_unfinished_requests(): if not self.engine.has_unfinished_requests():
...@@ -3443,7 +3443,7 @@ index a0dd7958..3204cfb8 100644 ...@@ -3443,7 +3443,7 @@ index a0dd7958..3204cfb8 100644
# Poll until there is work to do. # Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
# When there's no work, check on engine health and send # When there's no work, check on engine health and send
@@ -220,6 +333,13 @@ class MQLLMEngine: @@ -220,6 +335,13 @@ class MQLLMEngine:
def handle_new_input(self): def handle_new_input(self):
"""Handle new input from the socket""" """Handle new input from the socket"""
try: try:
...@@ -3457,7 +3457,7 @@ index a0dd7958..3204cfb8 100644 ...@@ -3457,7 +3457,7 @@ index a0dd7958..3204cfb8 100644
while self.input_socket.poll(timeout=0) != 0: while self.input_socket.poll(timeout=0) != 0:
frames = self.input_socket.recv_multipart(copy=False) frames = self.input_socket.recv_multipart(copy=False)
request = pickle.loads(frames[0].buffer) request = pickle.loads(frames[0].buffer)
@@ -262,6 +382,11 @@ class MQLLMEngine: @@ -262,6 +384,11 @@ class MQLLMEngine:
self._send_outputs(rpc_err) self._send_outputs(rpc_err)
try: try:
...@@ -3469,7 +3469,7 @@ index a0dd7958..3204cfb8 100644 ...@@ -3469,7 +3469,7 @@ index a0dd7958..3204cfb8 100644
self.engine.add_request( self.engine.add_request(
request_id=request_id, request_id=request_id,
prompt=request.prompt, prompt=request.prompt,
@@ -269,7 +394,9 @@ class MQLLMEngine: @@ -269,7 +396,9 @@ class MQLLMEngine:
lora_request=request.lora_request, lora_request=request.lora_request,
trace_headers=request.trace_headers, trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request, prompt_adapter_request=request.prompt_adapter_request,
......
...@@ -237,7 +237,7 @@ kv-router-run.sh <number_of_workers> <routing_strategy> Optional[<model_name>] ...@@ -237,7 +237,7 @@ kv-router-run.sh <number_of_workers> <routing_strategy> Optional[<model_name>]
Example: Example:
```bash ```bash
# Launch 8 workers with prefix routing strategy and use deepseek-ai/DeepSeek-R1-Distill-Llama-8B as the model # Launch 8 workers with prefix routing strategy and use deepseek-ai/DeepSeek-R1-Distill-Llama-8B as the model
bash /workspace/examples/python_rs/llm/vllm/scripts/kv-router-run.sh 8 prefix deepseek-ai/DeepSeek-R1-Distill-Llama-8B bash /workspace/examples/python_rs/llm/vllm/scripts/kv-router-run.sh 8 test deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# List tmux sessions # List tmux sessions
tmux ls tmux ls
......
...@@ -15,69 +15,20 @@ ...@@ -15,69 +15,20 @@
import asyncio import asyncio
import random
from argparse import Namespace from argparse import Namespace
from enum import Enum
from typing import AsyncIterator from typing import AsyncIterator
import uvloop import uvloop
from common.protocol import Tokens from common.protocol import Tokens
from vllm.logger import logger as vllm_logger from vllm.logger import logger as vllm_logger
from dynamo.llm import KvIndexer, KvMetricsAggregator, KvRouter from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
WorkerId = str WorkerId = str
class RoutingStrategy(Enum):
PREFIX = "prefix"
ROUND_ROBIN = "round_robin"
RANDOM = "random"
class Router:
"""
Request handler for the generate endpoint
"""
def __init__(
self,
router: KvRouter,
routing_strategy: RoutingStrategy = RoutingStrategy.PREFIX,
):
vllm_logger.info(
f"Initializing KV Router with strategy: {routing_strategy.value}"
)
self.router = router
self.routing_strategy = routing_strategy
@dynamo_endpoint(Tokens, WorkerId)
async def generate(self, request) -> AsyncIterator[WorkerId]:
lora_id = 0
worker_id = None
if self.routing_strategy == RoutingStrategy.PREFIX:
try:
worker_id = await self.router.schedule(request.tokens, lora_id)
# [NOTE][TODO] Now that the scheduler may return more error messages,
# now we are catching all exceptions and logging them. Should have
# catch specific router exceptions once we have dedicated types.
except Exception as e:
vllm_logger.info(f"{e}")
worker_id = ""
vllm_logger.exception(f"Error during worker selection: {e}")
vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
yield str(worker_id)
else:
# TODO: Do we implement round_robin and random here?
# or just skip this router and directly enable in preprocess?
raise NotImplementedError(
f"Routing strategy {self.routing_strategy} not implemented"
)
class CustomRouter: class CustomRouter:
""" """
Request handler for the generate endpoint Request handler for the generate endpoint
...@@ -85,28 +36,117 @@ class CustomRouter: ...@@ -85,28 +36,117 @@ class CustomRouter:
def __init__( def __init__(
self, self,
workers_client,
indexer: KvIndexer, indexer: KvIndexer,
metrics_aggregator: KvMetricsAggregator, metrics_aggregator: KvMetricsAggregator,
): ):
vllm_logger.info("Initializing Custom Router")
self.indexer = indexer self.indexer = indexer
self.metrics_aggregator = metrics_aggregator self.metrics_aggregator = metrics_aggregator
self.workers_client = workers_client
def _cost_function(self, scores, metrics): def _cost_function(
# naive cost function for demonstration purposes self,
current_best = ("", 0) scores: OverlapScores | None,
for worker_id, score in scores.scores.items(): metrics: AggregatedMetrics | None,
if score > current_best[1]: token_length: int,
current_best = (worker_id, score) ):
for endpoint in metrics.endpoints: worker_scores = {}
if endpoint.worker_id == current_best[0]: if scores:
print(f"Metrics of endpoint: {endpoint.worker_id}") for worker_id, score in scores.scores.items():
print( # score is number of matching blocks we multiply by block_size to get tokens
f"request slot usage: {endpoint.request_active_slots} / {endpoint.request_total_slots}" # and compare to token_length. The larger the cache hit the better
) worker_scores[worker_id] = (
print( score * self.indexer.block_size() / token_length
f"KV block usage: {endpoint.kv_active_blocks} / {endpoint.kv_total_blocks}"
) )
return current_best[0]
worker_metrics = {}
# pull metrics for each worker
max_waiting = 0.0
if metrics:
for endpoint in metrics.endpoints:
worker_id = endpoint.worker_id
worker_metrics[worker_id] = {
"gpu_cache_usage_perc": endpoint.gpu_cache_usage_perc
if hasattr(endpoint, "gpu_cache_usage_perc")
else 0.0,
"num_requests_waiting": endpoint.num_requests_waiting
if hasattr(endpoint, "num_requests_waiting")
else 0.0,
"gpu_prefix_cache_hit_rate": endpoint.gpu_prefix_cache_hit_rate
if hasattr(endpoint, "gpu_prefix_cache_hit_rate")
else 0.0,
}
max_waiting = max(
max_waiting, worker_metrics[worker_id]["num_requests_waiting"]
)
# Get all worker IDs from the client. This is needed because scores / metrics may not have values for all workers
# and we want all workers to be considered in the logit calculation
worker_ids = self.workers_client.endpoint_ids()
worker_logits = {}
for worker_id in worker_ids:
# Use default values if worker not in scores or metrics
score = worker_scores.get(worker_id, 0.0)
metrics_dict = worker_metrics.get(
worker_id,
{
"gpu_cache_usage_perc": 0.0,
"num_requests_waiting": 0.0,
"gpu_prefix_cache_hit_rate": 0.0,
},
)
normalized_waiting = (
metrics_dict["num_requests_waiting"] / max_waiting
if max_waiting > 0
else 0.0
)
# Have 1 metric that weights towards cache hit
# 2 metrics that penalize overloaded worker and queuing
worker_logits[worker_id] = (
2 * score - metrics_dict["gpu_cache_usage_perc"] - normalized_waiting
)
vllm_logger.info(
f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = 2.0 * {score:.3f} - {metrics_dict['gpu_cache_usage_perc']:.3f} - {normalized_waiting:.3f}"
)
if not worker_logits or all(logit == 0 for logit in worker_logits.values()):
return ""
# Select the worker with the highest logit
if worker_logits:
max_logit = max(worker_logits.values())
best_workers = [
wid for wid, logit in worker_logits.items() if logit == max_logit
]
best_worker_id = random.choice(best_workers)
else:
best_worker_id = ""
# Log the metrics for the selected worker
if best_worker_id:
vllm_logger.info(
f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}"
)
vllm_logger.info(
f"Score: {scores.scores.get(best_worker_id, 0.0) if scores else 0.0:.3f}"
)
metrics_dict = worker_metrics.get(best_worker_id, {})
vllm_logger.info(
f"GPU Cache Hit Rate: {metrics_dict.get('gpu_prefix_cache_hit_rate', 0.0):.3f}"
)
vllm_logger.info(
f"GPU Cache Usage: {metrics_dict.get('gpu_cache_usage_perc', 0.0):.3f}"
)
vllm_logger.info(
f"Requests Waiting: {metrics_dict.get('num_requests_waiting', 0.0) / max_waiting if max_waiting > 0 else 0.0:.3f}"
)
return best_worker_id
@dynamo_endpoint(Tokens, WorkerId) @dynamo_endpoint(Tokens, WorkerId)
async def generate(self, request) -> AsyncIterator[WorkerId]: async def generate(self, request) -> AsyncIterator[WorkerId]:
...@@ -116,18 +156,16 @@ class CustomRouter: ...@@ -116,18 +156,16 @@ class CustomRouter:
scores = await self.indexer.find_matches_for_request( scores = await self.indexer.find_matches_for_request(
request.tokens, lora_id request.tokens, lora_id
) )
metrics = await self.metrics_aggregator.get_metrics()
worker_id = self._cost_function(scores, metrics)
# [NOTE][TODO] Now that the scheduler may return more error messages,
# now we are catching all exceptions and logging them. Should have
# catch specific router exceptions once we have dedicated types.
except Exception as e: except Exception as e:
vllm_logger.info(f"{e}") scores = {}
worker_id = "" vllm_logger.exception(f"Error finding matches: {e}")
vllm_logger.exception(f"Error during worker selection: {e}")
token_length = len(request.tokens)
metrics = await self.metrics_aggregator.get_metrics()
worker_id = self._cost_function(scores, metrics, token_length)
vllm_logger.info(f"Scheduling to worker_id: {worker_id}") vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
vllm_logger.info("########")
yield str(worker_id) yield str(worker_id)
...@@ -144,14 +182,6 @@ async def worker(runtime: DistributedRuntime, args: Namespace): ...@@ -144,14 +182,6 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
.endpoint("generate") .endpoint("generate")
.client() .client()
) )
wait_task = workers_client.wait_for_endpoints()
await asyncio.sleep(1)
while not wait_task.done():
vllm_logger.info("Waiting for workers to be ready...")
await asyncio.sleep(5)
wait_task.result()
while len(workers_client.endpoint_ids()) < args.min_workers: while len(workers_client.endpoint_ids()) < args.min_workers:
vllm_logger.info( vllm_logger.info(
...@@ -172,23 +202,11 @@ async def worker(runtime: DistributedRuntime, args: Namespace): ...@@ -172,23 +202,11 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
endpoint = router_component.endpoint("generate") endpoint = router_component.endpoint("generate")
if args.custom_router: indexer = KvIndexer(kv_listener, args.block_size)
# @REVIEWER - I'm not currently checking if block size matches that of the engine metrics_aggregator = KvMetricsAggregator(kv_listener)
# If they don't match things will silently fail await endpoint.serve_endpoint(
# The preferred solution would be for the KV Indexer to read from the MDC in etcd and not bother the user at all CustomRouter(workers_client, indexer, metrics_aggregator).generate
# The second solution would be to do KvIndexer(kv_listener, MDC.block_size) )
# as this ensures block size matches that of the engine
# In this case we need to do some sort of handshake or check in case a user just puts in a random block size
indexer = KvIndexer(kv_listener, args.block_size)
metrics_aggregator = KvMetricsAggregator(kv_listener)
await endpoint.serve_endpoint(
CustomRouter(indexer, metrics_aggregator).generate
)
else:
# TODO Read block_size from MDC
router = KvRouter(runtime, kv_listener, args.block_size)
await endpoint.serve_endpoint(Router(router, args.routing_strategy).generate)
if __name__ == "__main__": if __name__ == "__main__":
...@@ -197,35 +215,18 @@ if __name__ == "__main__": ...@@ -197,35 +215,18 @@ if __name__ == "__main__":
import argparse import argparse
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument(
"--routing-strategy",
type=RoutingStrategy,
default=RoutingStrategy.PREFIX,
choices=list(RoutingStrategy),
help="Routing strategy to use",
)
parser.add_argument( parser.add_argument(
"--min-workers", "--min-workers",
type=int, type=int,
default=1, default=1,
help="Minimum number of workers required before proceeding", help="Minimum number of workers required before proceeding",
) )
parser.add_argument( # TODO: Read block size
"--model-name",
type=str,
default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
help="Model that is being served",
)
parser.add_argument( parser.add_argument(
"--block-size", "--block-size",
type=int, type=int,
help="KV block size", default=64,
) help="Block size for the KV Indexer",
parser.add_argument(
"--custom-router",
type=bool,
default=False,
help="Whether to use custom router or not",
) )
args = parser.parse_args() args = parser.parse_args()
......
...@@ -99,10 +99,13 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs): ...@@ -99,10 +99,13 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
# Initially send dummy metrics to kick start, # Initially send dummy metrics to kick start,
# vLLM will not update stat until forward pass is triggered # vLLM will not update stat until forward pass is triggered
metrics_publisher.publish( metrics_publisher.publish(
0, 0, # request_active_slots
1024, 1024, # request_total_slots
0, 0, # kv_active_blocks
1024, 1024, # kv_total_blocks
0, # num_requests_waiting
0.0, # gpu_cache_usage_perc
0.0, # gpu_prefix_cache_hit_rate
) )
await asyncio.gather( await asyncio.gather(
......
...@@ -18,53 +18,78 @@ ...@@ -18,53 +18,78 @@
# - Must use a single GPU for workers as CUDA_VISIBLE_DEVICES is set to a fixed value # - Must use a single GPU for workers as CUDA_VISIBLE_DEVICES is set to a fixed value
# - Must use a single node # - Must use a single node
if [ $# -lt 2 ]; then if [ $# -lt 3 ]; then
echo "Usage: $0 <number_of_workers> <routing_strategy> [model_name] [endpoint_name]" echo "Usage: $0 <number_of_workers> <log_dir_name> [model_name] [model_args] [chat_endpoint_name] [completions_endpoint_name]"
echo "Error: Must specify at least number of workers and routing strategy" echo "Error: Must specify at least number of workers, log_dir_name"
echo "Optional: model_name (default: deepseek-ai/DeepSeek-R1-Distill-Llama-8B)" echo "Optional: model_name (default: deepseek-ai/DeepSeek-R1-Distill-Llama-8B)"
echo "Optional: endpoint_name (default: dynamo.process.chat/completions)" echo "Optional: model_args (quoted string with model arguments)"
echo "Optional: chat_endpoint_name (default: dynamo.process.chat/completions)"
echo "Optional: completions_endpoint_name (default: dynamo.process.completions)"
exit 1 exit 1
fi fi
# Uncomment if using Cache
# export HF_HUB_OFFLINE=1
# https://github.com/vllm-project/vllm/issues/10734#issuecomment-2507201353
# Fix for:torch.distributed.DistBackendError: File name too long
# export GLOO_SOCKET_IFNAME=lo
NUM_WORKERS=$1 NUM_WORKERS=$1
ROUTING_STRATEGY=$2 LOG_DIR_NAME=$2
MODEL_NAME=${3:-"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"} MODEL_NAME=${3:-"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"}
ENDPOINT_NAME=${4:-"dynamo.process.chat/completions"} CUSTOM_MODEL_ARGS=$4
VALID_STRATEGIES=("prefix") CHAT_ENDPOINT_NAME=${5:-"dynamo.process.chat/completions"}
COMPLETIONS_ENDPOINT_NAME=${6:-"dynamo.process.completions"}
SESSION_NAME="v" SESSION_NAME="v"
WORKDIR="/workspace/examples/python_rs/llm/vllm" WORKDIR="/workspace/examples/python_rs/llm/vllm"
INIT_CMD="cd $WORKDIR" INIT_CMD="cd $WORKDIR"
if [[ ! " ${VALID_STRATEGIES[@]} " =~ " ${ROUTING_STRATEGY} " ]]; then
echo "Error: Invalid routing strategy. Must be one of: ${VALID_STRATEGIES[*]}" # Default model args
exit 1 DEFAULT_MODEL_ARGS="--model $MODEL_NAME \
--tokenizer $MODEL_NAME \
--enable-prefix-caching \
--block-size 64"
# Use custom model args if provided, otherwise use default
if [ -n "$CUSTOM_MODEL_ARGS" ]; then
MODEL_ARGS="$CUSTOM_MODEL_ARGS"
echo "Using custom model arguments"
else
MODEL_ARGS="$DEFAULT_MODEL_ARGS"
echo "Using default model arguments"
fi fi
# Create logs directory if it doesn't exist
LOGS_DIR="/logs/$LOG_DIR_NAME"
mkdir -p $LOGS_DIR
chmod -R 775 $LOGS_DIR
######################################################## ########################################################
# HTTP Server # HTTP Server
######################################################## ########################################################
HTTP_CMD="DYN_LOG=DEBUG http" HTTP_CMD="DYN_LOG=DEBUG http |& tee $LOGS_DIR/http.log"
tmux new-session -d -s "$SESSION_NAME-http" tmux new-session -d -s "$SESSION_NAME-http"
tmux send-keys -t "$SESSION_NAME-http" "$INIT_CMD && $HTTP_CMD" C-m tmux send-keys -t "$SESSION_NAME-http" "$INIT_CMD && $HTTP_CMD" C-m
######################################################## ########################################################
# LLMCTL # LLMCTL
######################################################## ########################################################
LLMCTL_CMD="sleep 5 && llmctl http remove chat-model $MODEL_NAME && \ LLMCTL_CMD="sleep 5 && \
llmctl http add chat-model $MODEL_NAME $ENDPOINT_NAME && \ llmctl http remove chat $MODEL_NAME && \
llmctl http list chat-model" llmctl http remove completions $MODEL_NAME && \
llmctl http add chat $MODEL_NAME $CHAT_ENDPOINT_NAME && \
llmctl http add completions $MODEL_NAME $COMPLETIONS_ENDPOINT_NAME && \
llmctl http list |& tee $LOGS_DIR/llmctl.log"
tmux new-session -d -s "$SESSION_NAME-llmctl" tmux new-session -d -s "$SESSION_NAME-llmctl"
tmux send-keys -t "$SESSION_NAME-llmctl" "$INIT_CMD && $LLMCTL_CMD" C-m tmux send-keys -t "$SESSION_NAME-llmctl" "$INIT_CMD && $LLMCTL_CMD" C-m
######################################################## ########################################################
# Processor # Processor
######################################################## ########################################################
# For now processor gets same args as worker, need to have them communicate over etcd PROCESSOR_CMD="RUST_LOG=info python3 -m kv_router.processor $MODEL_ARGS |& tee $LOGS_DIR/processor.log"
PROCESSOR_CMD="RUST_LOG=info python3 -m kv_router.processor \
--model $MODEL_NAME \
--tokenizer $MODEL_NAME \
--enable-prefix-caching \
--block-size 32 \
--max-model-len 16384 "
tmux new-session -d -s "$SESSION_NAME-processor" tmux new-session -d -s "$SESSION_NAME-processor"
tmux send-keys -t "$SESSION_NAME-processor" "$INIT_CMD && $PROCESSOR_CMD" C-m tmux send-keys -t "$SESSION_NAME-processor" "$INIT_CMD && $PROCESSOR_CMD" C-m
...@@ -72,10 +97,7 @@ tmux send-keys -t "$SESSION_NAME-processor" "$INIT_CMD && $PROCESSOR_CMD" C-m ...@@ -72,10 +97,7 @@ tmux send-keys -t "$SESSION_NAME-processor" "$INIT_CMD && $PROCESSOR_CMD" C-m
# Router # Router
######################################################## ########################################################
ROUTER_CMD="RUST_LOG=info python3 -m kv_router.router \ ROUTER_CMD="RUST_LOG=info python3 -m kv_router.router \
--model $MODEL_NAME \ --min-workers $NUM_WORKERS |& tee $LOGS_DIR/router.log"
--routing-strategy $ROUTING_STRATEGY \
--min-workers $NUM_WORKERS \
--block-size 32"
tmux new-session -d -s "$SESSION_NAME-router" tmux new-session -d -s "$SESSION_NAME-router"
tmux send-keys -t "$SESSION_NAME-router" "$INIT_CMD && $ROUTER_CMD" C-m tmux send-keys -t "$SESSION_NAME-router" "$INIT_CMD && $ROUTER_CMD" C-m
...@@ -83,17 +105,12 @@ tmux send-keys -t "$SESSION_NAME-router" "$INIT_CMD && $ROUTER_CMD" C-m ...@@ -83,17 +105,12 @@ tmux send-keys -t "$SESSION_NAME-router" "$INIT_CMD && $ROUTER_CMD" C-m
######################################################## ########################################################
# Workers # Workers
######################################################## ########################################################
WORKER_CMD="RUST_LOG=info python3 -m kv_router.worker \ WORKER_CMD="RUST_LOG=info python3 -m kv_router.worker $MODEL_ARGS"
--model $MODEL_NAME \
--tokenizer $MODEL_NAME \
--enable-prefix-caching \
--block-size 64 \
--max-model-len 16384 "
for i in $(seq 1 $NUM_WORKERS); do for i in $(seq 1 $NUM_WORKERS); do
tmux new-session -d -s "$SESSION_NAME-$i" tmux new-session -d -s "$SESSION_NAME-$i"
done done
for i in $(seq 1 $NUM_WORKERS); do for i in $(seq 1 $NUM_WORKERS); do
tmux send-keys -t "$SESSION_NAME-$i" "$INIT_CMD && CUDA_VISIBLE_DEVICES=$((i-1)) $WORKER_CMD" C-m tmux send-keys -t "$SESSION_NAME-$i" "$INIT_CMD && CUDA_VISIBLE_DEVICES=$((i-1)) $WORKER_CMD |& tee $LOGS_DIR/worker-$i.log" C-m
done done
\ No newline at end of file
...@@ -228,6 +228,10 @@ async fn add_model( ...@@ -228,6 +228,10 @@ async fn add_model(
endpoint_name endpoint_name
); );
if model_name.starts_with('/') {
raise!("Model name '{}' cannot start with a slash", model_name);
}
let parts: Vec<&str> = endpoint_name.split('.').collect(); let parts: Vec<&str> = endpoint_name.split('.').collect();
if parts.len() < 2 { if parts.len() < 2 {
......
...@@ -90,6 +90,7 @@ impl KvMetricsPublisher { ...@@ -90,6 +90,7 @@ impl KvMetricsPublisher {
}) })
} }
#[allow(clippy::too_many_arguments)]
fn publish( fn publish(
&self, &self,
_py: Python, _py: Python,
...@@ -97,6 +98,9 @@ impl KvMetricsPublisher { ...@@ -97,6 +98,9 @@ impl KvMetricsPublisher {
request_total_slots: u64, request_total_slots: u64,
kv_active_blocks: u64, kv_active_blocks: u64,
kv_total_blocks: u64, kv_total_blocks: u64,
num_requests_waiting: u64,
gpu_cache_usage_perc: f32,
gpu_prefix_cache_hit_rate: f32,
) -> PyResult<()> { ) -> PyResult<()> {
self.inner self.inner
.publish( .publish(
...@@ -105,6 +109,9 @@ impl KvMetricsPublisher { ...@@ -105,6 +109,9 @@ impl KvMetricsPublisher {
request_total_slots, request_total_slots,
kv_active_blocks, kv_active_blocks,
kv_total_blocks, kv_total_blocks,
num_requests_waiting,
gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate,
} }
.into(), .into(),
) )
...@@ -180,6 +187,10 @@ impl KvIndexer { ...@@ -180,6 +187,10 @@ impl KvIndexer {
}) })
} }
fn block_size(&self) -> usize {
self.inner.block_size()
}
fn find_matches_for_request<'p>( fn find_matches_for_request<'p>(
&self, &self,
py: Python<'p>, py: Python<'p>,
...@@ -212,6 +223,12 @@ pub(crate) struct EndpointKvMetrics { ...@@ -212,6 +223,12 @@ pub(crate) struct EndpointKvMetrics {
pub kv_active_blocks: u64, pub kv_active_blocks: u64,
#[pyo3(get, set)] #[pyo3(get, set)]
pub kv_total_blocks: u64, pub kv_total_blocks: u64,
#[pyo3(get, set)]
pub num_requests_waiting: u64,
#[pyo3(get, set)]
pub gpu_cache_usage_perc: f32,
#[pyo3(get, set)]
pub gpu_prefix_cache_hit_rate: f32,
} }
#[pyclass] #[pyclass]
...@@ -258,6 +275,9 @@ impl KvMetricsAggregator { ...@@ -258,6 +275,9 @@ impl KvMetricsAggregator {
request_total_slots: x.data.request_total_slots, request_total_slots: x.data.request_total_slots,
kv_active_blocks: x.data.kv_active_blocks, kv_active_blocks: x.data.kv_active_blocks,
kv_total_blocks: x.data.kv_total_blocks, kv_total_blocks: x.data.kv_total_blocks,
num_requests_waiting: x.data.num_requests_waiting,
gpu_cache_usage_perc: x.data.gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate: x.data.gpu_prefix_cache_hit_rate,
}) })
.collect(); .collect();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
......
...@@ -242,7 +242,11 @@ class KvMetricsPublisher: ...@@ -242,7 +242,11 @@ class KvMetricsPublisher:
def publish(self, request_active_slots: int, def publish(self, request_active_slots: int,
request_total_slots: int, request_total_slots: int,
kv_active_blocks: int, kv_active_blocks: int,
kv_total_blocks: int) -> None: kv_total_blocks: int,
num_requests_waiting: int,
gpu_cache_usage_perc: float,
gpu_prefix_cache_hit_rate: float
) -> None:
""" """
Update the KV metrics being reported. Update the KV metrics being reported.
""" """
...@@ -298,7 +302,7 @@ class KvIndexer: ...@@ -298,7 +302,7 @@ class KvIndexer:
... ...
def __init__(self, component: Component) -> None: def __init__(self, component: Component, block_size: int) -> None:
""" """
Create a `KvIndexer` object Create a `KvIndexer` object
""" """
...@@ -309,6 +313,12 @@ class KvIndexer: ...@@ -309,6 +313,12 @@ class KvIndexer:
""" """
... ...
def block_size(self) -> int:
"""
Return the block size of the KV Indexer.
"""
...
class AggregatedMetrics: class AggregatedMetrics:
""" """
A collection of metrics of the endpoints A collection of metrics of the endpoints
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dynamo._core import AggregatedMetrics as AggregatedMetrics
from dynamo._core import DisaggregatedRouter as DisaggregatedRouter from dynamo._core import DisaggregatedRouter as DisaggregatedRouter
from dynamo._core import HttpAsyncEngine as HttpAsyncEngine from dynamo._core import HttpAsyncEngine as HttpAsyncEngine
from dynamo._core import HttpError as HttpError from dynamo._core import HttpError as HttpError
...@@ -21,3 +22,4 @@ from dynamo._core import KvIndexer as KvIndexer ...@@ -21,3 +22,4 @@ from dynamo._core import KvIndexer as KvIndexer
from dynamo._core import KvMetricsAggregator as KvMetricsAggregator from dynamo._core import KvMetricsAggregator as KvMetricsAggregator
from dynamo._core import KvMetricsPublisher as KvMetricsPublisher from dynamo._core import KvMetricsPublisher as KvMetricsPublisher
from dynamo._core import KvRouter as KvRouter from dynamo._core import KvRouter as KvRouter
from dynamo._core import OverlapScores as OverlapScores
...@@ -193,6 +193,9 @@ async def test_metrics_aggregator(distributed_runtime): ...@@ -193,6 +193,9 @@ async def test_metrics_aggregator(distributed_runtime):
"request_total_slots": 1024, "request_total_slots": 1024,
"kv_active_blocks": 523, "kv_active_blocks": 523,
"kv_total_blocks": 777, "kv_total_blocks": 777,
"num_requests_waiting": 10,
"gpu_cache_usage_perc": 0.5,
"gpu_prefix_cache_hit_rate": 0.75,
} }
# need 'create_task' to put publisher task in the background # need 'create_task' to put publisher task in the background
...@@ -222,5 +225,8 @@ async def metrics_publisher_task(kv_listener, expected_metrics): ...@@ -222,5 +225,8 @@ async def metrics_publisher_task(kv_listener, expected_metrics):
expected_metrics["request_total_slots"], expected_metrics["request_total_slots"],
expected_metrics["kv_active_blocks"], expected_metrics["kv_active_blocks"],
expected_metrics["kv_total_blocks"], expected_metrics["kv_total_blocks"],
expected_metrics["num_requests_waiting"],
expected_metrics["gpu_cache_usage_perc"],
expected_metrics["gpu_prefix_cache_hit_rate"],
) )
await metrics_publisher.create_endpoint(kv_listener) await metrics_publisher.create_endpoint(kv_listener)
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
use anyhow::Result; use anyhow::Result;
use dynamo_runtime::{component::Component, component::Namespace, DistributedRuntime}; use dynamo_runtime::{component::Component, component::Namespace, DistributedRuntime};
use futures::stream::StreamExt; use futures::stream::StreamExt;
use std::{sync::Arc, time::Duration}; use std::sync::Arc;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing; use tracing;
...@@ -29,7 +29,8 @@ pub mod scoring; ...@@ -29,7 +29,8 @@ pub mod scoring;
use crate::kv_router::{ use crate::kv_router::{
indexer::{KvIndexer, KvIndexerInterface, RouterEvent}, indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
scheduler::{Endpoint, KvScheduler, Service}, metrics_aggregator::collect_endpoints,
scheduler::KvScheduler,
scoring::ProcessedEndpoints, scoring::ProcessedEndpoints,
}; };
...@@ -148,73 +149,3 @@ impl KvRouter { ...@@ -148,73 +149,3 @@ impl KvRouter {
Ok(worker_id) Ok(worker_id)
} }
} }
async fn collect_endpoints(
nats_client: dynamo_runtime::transports::nats::Client,
service_name: String,
ep_tx: tokio::sync::mpsc::Sender<ProcessedEndpoints>,
cancel: CancellationToken,
) {
loop {
tokio::select! {
_ = cancel.cancelled() => {
tracing::debug!("cancellation token triggered");
break;
}
_ = tokio::time::sleep(Duration::from_secs(1)) => {
tracing::trace!("collecting endpoints for service: {}", service_name);
}
}
let values = match nats_client
.get_endpoints(&service_name, Duration::from_secs(1))
.await
{
Ok(v) => v,
Err(e) => {
tracing::warn!("Failed to retrieve endpoints for {}: {:?}", service_name, e);
continue;
}
};
tracing::debug!("values: {:?}", values);
let services: Vec<Service> = values
.into_iter()
.filter(|v| !v.is_empty())
.filter_map(|v| match serde_json::from_slice::<Service>(&v) {
Ok(service) => Some(service),
Err(e) => {
tracing::warn!("For value: {:?} \nFailed to parse service: {:?}", v, e);
None
}
})
.collect();
tracing::debug!("services: {:?}", services);
let endpoints: Vec<Endpoint> = services
.into_iter()
.flat_map(|s| s.endpoints)
.filter(|s| s.data.is_some())
.map(|s| Endpoint {
name: s.name,
subject: s.subject,
data: s.data.unwrap(),
})
.collect();
tracing::debug!("endpoints: {:?}", endpoints);
tracing::trace!(
"found {} endpoints for service: {}",
endpoints.len(),
service_name
);
let processed = ProcessedEndpoints::new(endpoints);
// process endpoints into
if ep_tx.send(processed).await.is_err() {
tracing::trace!("failed to send processed endpoints; shutting down");
break;
}
}
}
...@@ -588,6 +588,10 @@ impl KvIndexer { ...@@ -588,6 +588,10 @@ impl KvIndexer {
} }
} }
pub fn block_size(&self) -> usize {
self.kv_block_size
}
pub fn new(token: CancellationToken, kv_block_size: usize) -> Self { pub fn new(token: CancellationToken, kv_block_size: usize) -> Self {
Self::new_with_frequency(token, None, kv_block_size) Self::new_with_frequency(token, None, kv_block_size)
} }
...@@ -775,6 +779,10 @@ impl KvIndexerSharded { ...@@ -775,6 +779,10 @@ impl KvIndexerSharded {
} }
} }
pub fn block_size(&self) -> usize {
self.kv_block_size
}
pub fn new(token: CancellationToken, num_shards: usize, kv_block_size: usize) -> Self { pub fn new(token: CancellationToken, num_shards: usize, kv_block_size: usize) -> Self {
Self::new_with_frequency(token, num_shards, None, kv_block_size) Self::new_with_frequency(token, num_shards, None, kv_block_size)
} }
......
...@@ -80,70 +80,71 @@ impl KvMetricsAggregator { ...@@ -80,70 +80,71 @@ impl KvMetricsAggregator {
} }
} }
async fn collect_endpoints( pub async fn collect_endpoints(
nats_client: dynamo_runtime::transports::nats::Client, nats_client: dynamo_runtime::transports::nats::Client,
service_name: String, service_name: String,
ep_tx: tokio::sync::mpsc::Sender<ProcessedEndpoints>, ep_tx: tokio::sync::mpsc::Sender<ProcessedEndpoints>,
cancel: CancellationToken, cancel: CancellationToken,
) { ) {
let backoff_delay = Duration::from_millis(100);
loop { loop {
tokio::select! { tokio::select! {
_ = cancel.cancelled() => { _ = cancel.cancelled() => {
tracing::debug!("cancellation token triggered"); tracing::debug!("cancellation token triggered");
break; break;
} }
_ = tokio::time::sleep(Duration::from_secs(1)) => { _ = tokio::time::sleep(backoff_delay) => {
tracing::trace!("collecting endpoints for service: {}", service_name); tracing::trace!("collecting endpoints for service: {}", service_name);
} let values = match nats_client
} .get_endpoints(&service_name, Duration::from_millis(300))
.await
let values = match nats_client {
.get_endpoints(&service_name, Duration::from_secs(1)) Ok(v) => v,
.await Err(e) => {
{ tracing::warn!("Failed to retrieve endpoints for {}: {:?}", service_name, e);
Ok(v) => v, continue;
Err(e) => { }
tracing::warn!("Failed to retrieve endpoints for {}: {:?}", service_name, e); };
continue;
}
};
tracing::debug!("values: {:?}", values); tracing::debug!("values: {:?}", values);
let services: Vec<Service> = values let services: Vec<Service> = values
.into_iter() .into_iter()
.filter(|v| !v.is_empty()) .filter(|v| !v.is_empty())
.filter_map(|v| match serde_json::from_slice::<Service>(&v) { .filter_map(|v| match serde_json::from_slice::<Service>(&v) {
Ok(service) => Some(service), Ok(service) => Some(service),
Err(e) => { Err(e) => {
tracing::warn!("For value: {:?} \nFailed to parse service: {:?}", v, e); tracing::warn!("For value: {:?} \nFailed to parse service: {:?}", v, e);
None None
} }
}) })
.collect(); .collect();
tracing::debug!("services: {:?}", services); tracing::debug!("services: {:?}", services);
let endpoints: Vec<Endpoint> = services let endpoints: Vec<Endpoint> = services
.into_iter() .into_iter()
.flat_map(|s| s.endpoints) .flat_map(|s| s.endpoints)
.filter(|s| s.data.is_some()) .filter(|s| s.data.is_some())
.map(|s| Endpoint { .map(|s| Endpoint {
name: s.name, name: s.name,
subject: s.subject, subject: s.subject,
data: s.data.unwrap(), data: s.data.unwrap(),
}) })
.collect(); .collect();
tracing::debug!("endpoints: {:?}", endpoints); tracing::debug!("endpoints: {:?}", endpoints);
tracing::trace!( tracing::trace!(
"found {} endpoints for service: {}", "found {} endpoints for service: {}",
endpoints.len(), endpoints.len(),
service_name service_name
); );
let processed = ProcessedEndpoints::new(endpoints); let processed = ProcessedEndpoints::new(endpoints);
if ep_tx.send(processed).await.is_err() { if ep_tx.send(processed).await.is_err() {
tracing::trace!("failed to send processed endpoints; shutting down"); tracing::trace!("failed to send processed endpoints; shutting down");
break; break;
}
}
} }
} }
} }
...@@ -21,6 +21,12 @@ pub struct ForwardPassMetrics { ...@@ -21,6 +21,12 @@ pub struct ForwardPassMetrics {
pub request_total_slots: u64, pub request_total_slots: u64,
pub kv_active_blocks: u64, pub kv_active_blocks: u64,
pub kv_total_blocks: u64, pub kv_total_blocks: u64,
// integer from 0 to large number
pub num_requests_waiting: u64,
// percentage represented as a float from 0 to 1
pub gpu_cache_usage_perc: f32,
// percentage represented as a float from 0 to 1
pub gpu_prefix_cache_hit_rate: f32,
} }
/// A [`BlockHash`] is a hash computed from the tokens_ids, extra_token_ids and the optional /// A [`BlockHash`] is a hash computed from the tokens_ids, extra_token_ids and the optional
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment