Commit 3f84cdad authored by Alec, committed by GitHub

feat: add new metrics and simple router cost fn (#88)

parent 2153ee81
......@@ -112,11 +112,17 @@ fn mock_stats_handler(_stats: Stats) -> serde_json::Value {
let request_active_slots = rand::thread_rng().gen_range(0..=request_total_slots);
let kv_total_blocks = 100;
let kv_active_blocks = rand::thread_rng().gen_range(0..=kv_total_blocks);
let num_requests_waiting = rand::thread_rng().gen_range(0..=100);
let gpu_cache_usage_perc = rand::thread_rng().gen_range(0.0..=1.0);
let gpu_prefix_cache_hit_rate = rand::thread_rng().gen_range(0.0..=1.0);
let stats = ForwardPassMetrics {
request_active_slots,
request_total_slots,
kv_active_blocks,
kv_total_blocks,
num_requests_waiting,
gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate,
};
println!("stats out: {:?}", stats);
serde_json::to_value(stats).unwrap()
......
......@@ -2990,7 +2990,7 @@ index 3cf1850e..ae006579 100644
+ gpu_cache_usage_perc: float
+ gpu_prefix_cache_hit_rate: float
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index 85b5f31e..fe719642 100644
index 85b5f31e..05030292 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -8,6 +8,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
......@@ -3084,7 +3084,7 @@ index 85b5f31e..fe719642 100644
@staticmethod
def is_unsupported_config(engine_args: AsyncEngineArgs):
# Pipeline parallel not yet supported
@@ -180,6 +210,63 @@ class MQLLMEngineClient(EngineClient):
@@ -180,6 +210,61 @@ class MQLLMEngineClient(EngineClient):
except Exception as e:
self._set_errored(e)
......@@ -3118,21 +3118,19 @@ index 85b5f31e..fe719642 100644
+ # Metrics received - check the message
+ message: Frame = await self.metrics_socket.recv(copy=False)
+ metrics = pickle.loads(message.buffer)
+ if self.metrics_publisher is not None:
+ if isinstance(metrics, KvMetrics):
+ self.metrics_publisher.publish(metrics.request_active_slots,
+ metrics.request_total_slots,
+ metrics.kv_active_blocks,
+ metrics.kv_total_blocks,
+ metrics.num_requests_waiting,
+ metrics.gpu_cache_usage_perc,
+ metrics.gpu_prefix_cache_hit_rate)
+ if isinstance(metrics, Stats):
+ # TODO
+ # Send the whole stats to user
+ pass
+
+ logger.debug("Metrics successful.")
+ if self.metrics_publisher is not None and isinstance(
+ metrics, KvMetrics
+ ):
+ self.metrics_publisher.publish(metrics.request_active_slots,
+ metrics.request_total_slots,
+ metrics.kv_active_blocks,
+ metrics.kv_total_blocks,
+ metrics.num_requests_waiting,
+ metrics.gpu_cache_usage_perc,
+ metrics.gpu_prefix_cache_hit_rate)
+ logger.debug("Metrics successful.")
+
+ # TODO: Investigate sending whole stats object
+
+ except asyncio.CancelledError:
+ logger.debug("Shutting down MQLLMEngineClient check metrics loop.")
......@@ -3148,7 +3146,7 @@ index 85b5f31e..fe719642 100644
async def run_output_handler_loop(self):
"""Get RequestOutputs from Engine and stream to Request Queues"""
@@ -278,12 +365,26 @@ class MQLLMEngineClient(EngineClient):
@@ -278,12 +363,26 @@ class MQLLMEngineClient(EngineClient):
# Wait until server is ready.
response = await self._wait_for_server_rpc(socket)
......@@ -3175,7 +3173,7 @@ index 85b5f31e..fe719642 100644
def close(self):
"""Destroy the ZeroMQ Context."""
@@ -293,6 +394,8 @@ class MQLLMEngineClient(EngineClient):
@@ -293,6 +392,8 @@ class MQLLMEngineClient(EngineClient):
# Cancel background tasks.
if self.health_loop is not None:
self.health_loop.cancel()
......@@ -3184,7 +3182,7 @@ index 85b5f31e..fe719642 100644
if self.output_loop is not None:
self.output_loop.cancel()
@@ -415,6 +518,9 @@ class MQLLMEngineClient(EngineClient):
@@ -415,6 +516,9 @@ class MQLLMEngineClient(EngineClient):
"""
if self._errored_with is not None:
raise self._errored_with
......@@ -3194,7 +3192,7 @@ index 85b5f31e..fe719642 100644
@property
def is_running(self) -> bool:
@@ -473,6 +579,7 @@ class MQLLMEngineClient(EngineClient):
@@ -473,6 +577,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
......@@ -3202,7 +3200,7 @@ index 85b5f31e..fe719642 100644
*,
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[RequestOutput, None]:
@@ -502,7 +609,8 @@ class MQLLMEngineClient(EngineClient):
@@ -502,7 +607,8 @@ class MQLLMEngineClient(EngineClient):
return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers,
......@@ -3212,7 +3210,7 @@ index 85b5f31e..fe719642 100644
@overload
def encode(
@@ -586,6 +694,7 @@ class MQLLMEngineClient(EngineClient):
@@ -586,6 +692,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
......@@ -3220,7 +3218,7 @@ index 85b5f31e..fe719642 100644
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
PoolingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
@@ -630,6 +739,12 @@ class MQLLMEngineClient(EngineClient):
@@ -630,6 +737,12 @@ class MQLLMEngineClient(EngineClient):
else:
lp_bytes = None
......@@ -3233,7 +3231,7 @@ index 85b5f31e..fe719642 100644
request_bytes = pickle.dumps(
RPCProcessRequest(
prompt=prompt,
@@ -639,11 +754,11 @@ class MQLLMEngineClient(EngineClient):
@@ -639,11 +752,11 @@ class MQLLMEngineClient(EngineClient):
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
......@@ -3247,7 +3245,7 @@ index 85b5f31e..fe719642 100644
await self.input_socket.send_multipart(parts, copy=False)
# 4) Stream the RequestOutputs from the output queue. Note
@@ -705,3 +820,6 @@ class MQLLMEngineClient(EngineClient):
@@ -705,3 +818,6 @@ class MQLLMEngineClient(EngineClient):
# Raise on error, otherwise happily return None
if isinstance(request_output, BaseException):
raise request_output
......@@ -3255,7 +3253,7 @@ index 85b5f31e..fe719642 100644
+ def set_metrics_publisher(self, metrics_publisher):
+ self.metrics_publisher = metrics_publisher
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
index a0dd7958..3204cfb8 100644
index a0dd7958..c82bc15b 100644
--- a/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py
@@ -3,35 +3,115 @@
......@@ -3355,23 +3353,23 @@ index a0dd7958..3204cfb8 100644
+ self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
+
+# TODO: Send entire stats object to the client
+class StatLogger(StatLoggerBase):
+ def __init__(
+ self,
+ metrics_socket
+ ):
+ self.metrics_socket = metrics_socket
+# class StatLogger(StatLoggerBase):
+# def __init__(
+# self,
+# metrics_socket
+# ):
+# self.metrics_socket = metrics_socket
+
+ def log(self, stats: Stats) -> None:
+ self._send_metrics(stats)
+# def log(self, stats: Stats) -> None:
+# self._send_metrics(stats)
+
+ def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+ pass
+# def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+# pass
+
+ def _send_metrics(self, stats: Stats):
+ if not self.metrics_socket.closed:
+ metrics_bytes = pickle.dumps(stats)
+ self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
+# def _send_metrics(self, stats: Stats):
+# if not self.metrics_socket.closed:
+# metrics_bytes = pickle.dumps(stats)
+# self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
+
+
+
......@@ -3379,7 +3377,7 @@ index a0dd7958..3204cfb8 100644
class MQLLMEngine:
"""A multiprocessing wrapper for :class:`LLMEngine`.
@@ -94,12 +174,35 @@ class MQLLMEngine:
@@ -94,12 +174,37 @@ class MQLLMEngine:
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
......@@ -3406,16 +3404,18 @@ index a0dd7958..3204cfb8 100644
+ self.engine.cache_config.num_gpu_blocks,
+ self.metrics_socket
+ )
+ self.general_stat_logger = StatLogger(
+ self.metrics_socket
+ )
+ self.engine.add_logger("kv_metrics", self.kv_stat_logger)
+ self.engine.add_logger("general_metrics", self.general_stat_logger)
+
+ # TODO investigate sending whole stats object
+ # self.general_stat_logger = StatLogger(
+ # self.metrics_socket
+ # )
+ # self.engine.add_logger("general_metrics", self.general_stat_logger)
+
@property
def dead_error(self) -> BaseException:
if self._errored_with is not None:
@@ -171,8 +274,17 @@ class MQLLMEngine:
@@ -171,8 +276,17 @@ class MQLLMEngine:
# Handle the query from the Client.
if request == RPCStartupRequest.IS_SERVER_READY:
tracing_enabled = self.engine.is_tracing_enabled()
......@@ -3435,7 +3435,7 @@ index a0dd7958..3204cfb8 100644
except Exception as e:
response = e
@@ -185,6 +297,7 @@ class MQLLMEngine:
@@ -185,6 +299,7 @@ class MQLLMEngine:
while True:
if not self.engine.has_unfinished_requests():
......@@ -3443,7 +3443,7 @@ index a0dd7958..3204cfb8 100644
# Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
# When there's no work, check on engine health and send
@@ -220,6 +333,13 @@ class MQLLMEngine:
@@ -220,6 +335,13 @@ class MQLLMEngine:
def handle_new_input(self):
"""Handle new input from the socket"""
try:
......@@ -3457,7 +3457,7 @@ index a0dd7958..3204cfb8 100644
while self.input_socket.poll(timeout=0) != 0:
frames = self.input_socket.recv_multipart(copy=False)
request = pickle.loads(frames[0].buffer)
@@ -262,6 +382,11 @@ class MQLLMEngine:
@@ -262,6 +384,11 @@ class MQLLMEngine:
self._send_outputs(rpc_err)
try:
......@@ -3469,7 +3469,7 @@ index a0dd7958..3204cfb8 100644
self.engine.add_request(
request_id=request_id,
prompt=request.prompt,
@@ -269,7 +394,9 @@ class MQLLMEngine:
@@ -269,7 +396,9 @@ class MQLLMEngine:
lora_request=request.lora_request,
trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request,
......
......@@ -237,7 +237,7 @@ kv-router-run.sh <number_of_workers> <routing_strategy> Optional[<model_name>]
Example:
```bash
# Launch 8 workers, writing logs under /logs/test, with deepseek-ai/DeepSeek-R1-Distill-Llama-8B as the model
bash /workspace/examples/python_rs/llm/vllm/scripts/kv-router-run.sh 8 prefix deepseek-ai/DeepSeek-R1-Distill-Llama-8B
bash /workspace/examples/python_rs/llm/vllm/scripts/kv-router-run.sh 8 test deepseek-ai/DeepSeek-R1-Distill-Llama-8B
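# Optionally pass a quoted string of custom model arguments as the fourth
# argument (the flags below mirror the script's defaults; adjust as needed)
bash /workspace/examples/python_rs/llm/vllm/scripts/kv-router-run.sh 8 test deepseek-ai/DeepSeek-R1-Distill-Llama-8B "--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B --tokenizer deepseek-ai/DeepSeek-R1-Distill-Llama-8B --enable-prefix-caching --block-size 64"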
# List tmux sessions
tmux ls
......
......@@ -15,69 +15,20 @@
import asyncio
import random
from argparse import Namespace
from enum import Enum
from typing import AsyncIterator
import uvloop
from common.protocol import Tokens
from vllm.logger import logger as vllm_logger
from dynamo.llm import KvIndexer, KvMetricsAggregator, KvRouter
from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
WorkerId = str
class RoutingStrategy(Enum):
PREFIX = "prefix"
ROUND_ROBIN = "round_robin"
RANDOM = "random"
class Router:
"""
Request handler for the generate endpoint
"""
def __init__(
self,
router: KvRouter,
routing_strategy: RoutingStrategy = RoutingStrategy.PREFIX,
):
vllm_logger.info(
f"Initializing KV Router with strategy: {routing_strategy.value}"
)
self.router = router
self.routing_strategy = routing_strategy
@dynamo_endpoint(Tokens, WorkerId)
async def generate(self, request) -> AsyncIterator[WorkerId]:
lora_id = 0
worker_id = None
if self.routing_strategy == RoutingStrategy.PREFIX:
try:
worker_id = await self.router.schedule(request.tokens, lora_id)
# [NOTE][TODO] Now that the scheduler may return more error messages,
# we catch all exceptions and log them. We should catch specific
# router exceptions once we have dedicated types.
except Exception as e:
vllm_logger.info(f"{e}")
worker_id = ""
vllm_logger.exception(f"Error during worker selection: {e}")
vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
yield str(worker_id)
else:
# TODO: Do we implement round_robin and random here?
# or just skip this router and directly enable in preprocess?
raise NotImplementedError(
f"Routing strategy {self.routing_strategy} not implemented"
)
class CustomRouter:
"""
Request handler for the generate endpoint
......@@ -85,28 +36,117 @@ class CustomRouter:
def __init__(
self,
workers_client,
indexer: KvIndexer,
metrics_aggregator: KvMetricsAggregator,
):
vllm_logger.info("Initializing Custom Router")
self.indexer = indexer
self.metrics_aggregator = metrics_aggregator
self.workers_client = workers_client
def _cost_function(self, scores, metrics):
# naive cost function for demonstration purposes
current_best = ("", 0)
for worker_id, score in scores.scores.items():
if score > current_best[1]:
current_best = (worker_id, score)
for endpoint in metrics.endpoints:
if endpoint.worker_id == current_best[0]:
print(f"Metrics of endpoint: {endpoint.worker_id}")
print(
f"request slot usage: {endpoint.request_active_slots} / {endpoint.request_total_slots}"
)
print(
f"KV block usage: {endpoint.kv_active_blocks} / {endpoint.kv_total_blocks}"
def _cost_function(
self,
scores: OverlapScores | None,
metrics: AggregatedMetrics | None,
token_length: int,
):
worker_scores = {}
if scores:
for worker_id, score in scores.scores.items():
# score is the number of matching blocks; multiply by block_size to get
# tokens and compare to token_length. The larger the cache hit, the better.
worker_scores[worker_id] = (
score * self.indexer.block_size() / token_length
)
return current_best[0]
worker_metrics = {}
# pull metrics for each worker
max_waiting = 0.0
if metrics:
for endpoint in metrics.endpoints:
worker_id = endpoint.worker_id
worker_metrics[worker_id] = {
"gpu_cache_usage_perc": endpoint.gpu_cache_usage_perc
if hasattr(endpoint, "gpu_cache_usage_perc")
else 0.0,
"num_requests_waiting": endpoint.num_requests_waiting
if hasattr(endpoint, "num_requests_waiting")
else 0.0,
"gpu_prefix_cache_hit_rate": endpoint.gpu_prefix_cache_hit_rate
if hasattr(endpoint, "gpu_prefix_cache_hit_rate")
else 0.0,
}
max_waiting = max(
max_waiting, worker_metrics[worker_id]["num_requests_waiting"]
)
# Get all worker IDs from the client. This is needed because scores / metrics may not have values for all workers
# and we want all workers to be considered in the logit calculation
worker_ids = self.workers_client.endpoint_ids()
worker_logits = {}
for worker_id in worker_ids:
# Use default values if worker not in scores or metrics
score = worker_scores.get(worker_id, 0.0)
metrics_dict = worker_metrics.get(
worker_id,
{
"gpu_cache_usage_perc": 0.0,
"num_requests_waiting": 0.0,
"gpu_prefix_cache_hit_rate": 0.0,
},
)
normalized_waiting = (
metrics_dict["num_requests_waiting"] / max_waiting
if max_waiting > 0
else 0.0
)
# One metric weights towards cache hits;
# two metrics penalize an overloaded worker and queuing
worker_logits[worker_id] = (
2 * score - metrics_dict["gpu_cache_usage_perc"] - normalized_waiting
)
vllm_logger.info(
f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = 2.0 * {score:.3f} - {metrics_dict['gpu_cache_usage_perc']:.3f} - {normalized_waiting:.3f}"
)
if not worker_logits or all(logit == 0 for logit in worker_logits.values()):
return ""
# Select the worker with the highest logit; the early return above
# guarantees worker_logits is non-empty here
max_logit = max(worker_logits.values())
best_workers = [
wid for wid, logit in worker_logits.items() if logit == max_logit
]
best_worker_id = random.choice(best_workers)
# Log the metrics for the selected worker
if best_worker_id:
vllm_logger.info(
f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}"
)
vllm_logger.info(
f"Score: {scores.scores.get(best_worker_id, 0.0) if scores else 0.0:.3f}"
)
metrics_dict = worker_metrics.get(best_worker_id, {})
vllm_logger.info(
f"GPU Cache Hit Rate: {metrics_dict.get('gpu_prefix_cache_hit_rate', 0.0):.3f}"
)
vllm_logger.info(
f"GPU Cache Usage: {metrics_dict.get('gpu_cache_usage_perc', 0.0):.3f}"
)
vllm_logger.info(
f"Requests Waiting: {metrics_dict.get('num_requests_waiting', 0.0) / max_waiting if max_waiting > 0 else 0.0:.3f}"
)
return best_worker_id
@dynamo_endpoint(Tokens, WorkerId)
async def generate(self, request) -> AsyncIterator[WorkerId]:
......@@ -116,18 +156,16 @@ class CustomRouter:
scores = await self.indexer.find_matches_for_request(
request.tokens, lora_id
)
metrics = await self.metrics_aggregator.get_metrics()
worker_id = self._cost_function(scores, metrics)
# [NOTE][TODO] Now that the scheduler may return more error messages,
# we catch all exceptions and log them. We should catch specific
# router exceptions once we have dedicated types.
except Exception as e:
vllm_logger.info(f"{e}")
worker_id = ""
vllm_logger.exception(f"Error during worker selection: {e}")
scores = {}
vllm_logger.exception(f"Error finding matches: {e}")
token_length = len(request.tokens)
metrics = await self.metrics_aggregator.get_metrics()
worker_id = self._cost_function(scores, metrics, token_length)
vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
vllm_logger.info("########")
yield str(worker_id)
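For reference, the scoring formula above can be exercised standalone. A minimal sketch with mock worker IDs and metric values (all numbers illustrative; real inputs come from `KvIndexer.find_matches_for_request` and `KvMetricsAggregator.get_metrics`, and the router breaks ties with `random.choice` rather than plain `max`):

```python
block_size = 64      # KV block size assumed for the indexer
token_length = 256   # length of the incoming request, in tokens

# worker_id -> number of matching KV blocks reported by the indexer
raw_scores = {"worker-0": 3, "worker-1": 1}
# worker_id -> metrics snapshot from the aggregator
metrics = {
    "worker-0": {"gpu_cache_usage_perc": 0.9, "num_requests_waiting": 4},
    "worker-1": {"gpu_cache_usage_perc": 0.2, "num_requests_waiting": 0},
}

max_waiting = max(m["num_requests_waiting"] for m in metrics.values())

logits = {}
for wid, m in metrics.items():
    # fraction of the request already covered by this worker's KV cache
    score = raw_scores.get(wid, 0) * block_size / token_length
    waiting = m["num_requests_waiting"] / max_waiting if max_waiting > 0 else 0.0
    # one term rewards cache hits; two penalize load and queuing
    logits[wid] = 2 * score - m["gpu_cache_usage_perc"] - waiting

best = max(logits, key=logits.get)
print(logits, "->", best)  # worker-1 wins despite the smaller cache hit
```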
......@@ -144,14 +182,6 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
.endpoint("generate")
.client()
)
wait_task = workers_client.wait_for_endpoints()
await asyncio.sleep(1)
while not wait_task.done():
vllm_logger.info("Waiting for workers to be ready...")
await asyncio.sleep(5)
wait_task.result()
while len(workers_client.endpoint_ids()) < args.min_workers:
vllm_logger.info(
......@@ -172,23 +202,11 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
endpoint = router_component.endpoint("generate")
if args.custom_router:
# @REVIEWER - I'm not currently checking if the block size matches that of the engine.
# If they don't match, things will silently fail.
# The preferred solution would be for the KV Indexer to read from the MDC in etcd and not bother the user at all.
# The second solution would be KvIndexer(kv_listener, MDC.block_size),
# as this ensures the block size matches that of the engine.
# In that case we need some sort of handshake or check in case a user puts in a random block size.
# (A sanity-check sketch using the new block_size() accessor follows this hunk.)
indexer = KvIndexer(kv_listener, args.block_size)
metrics_aggregator = KvMetricsAggregator(kv_listener)
await endpoint.serve_endpoint(
CustomRouter(indexer, metrics_aggregator).generate
)
else:
# TODO Read block_size from MDC
router = KvRouter(runtime, kv_listener, args.block_size)
await endpoint.serve_endpoint(Router(router, args.routing_strategy).generate)
indexer = KvIndexer(kv_listener, args.block_size)
metrics_aggregator = KvMetricsAggregator(kv_listener)
await endpoint.serve_endpoint(
CustomRouter(workers_client, indexer, metrics_aggregator).generate
)
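Given the block-size concern flagged in the comment above, the new `block_size()` accessor at least makes an explicit sanity check possible. A sketch, assuming a hypothetical `engine_block_size` obtained from the engine or MDC:

```python
indexer = KvIndexer(kv_listener, args.block_size)
# engine_block_size is hypothetical; the engine/MDC would need to expose it
if indexer.block_size() != engine_block_size:
    raise ValueError(
        f"KvIndexer block size {indexer.block_size()} does not match "
        f"engine block size {engine_block_size}; matching would silently fail"
    )
```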
if __name__ == "__main__":
......@@ -197,35 +215,18 @@ if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--routing-strategy",
type=RoutingStrategy,
default=RoutingStrategy.PREFIX,
choices=list(RoutingStrategy),
help="Routing strategy to use",
)
parser.add_argument(
"--min-workers",
type=int,
default=1,
help="Minimum number of workers required before proceeding",
)
parser.add_argument(
"--model-name",
type=str,
default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
help="Model that is being served",
)
# TODO: Read block size
parser.add_argument(
"--block-size",
type=int,
help="KV block size",
)
parser.add_argument(
"--custom-router",
type=bool,
default=False,
help="Whether to use custom router or not",
default=64,
help="Block size for the KV Indexer",
)
args = parser.parse_args()
......
......@@ -99,10 +99,13 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
# Initially send dummy metrics to kick-start publishing;
# vLLM will not update stats until a forward pass is triggered
metrics_publisher.publish(
0,
1024,
0,
1024,
0, # request_active_slots
1024, # request_total_slots
0, # kv_active_blocks
1024, # kv_total_blocks
0, # num_requests_waiting
0.0, # gpu_cache_usage_perc
0.0, # gpu_prefix_cache_hit_rate
)
await asyncio.gather(
......
......@@ -18,53 +18,78 @@
# - Must use a single GPU for workers as CUDA_VISIBLE_DEVICES is set to a fixed value
# - Must use a single node
if [ $# -lt 2 ]; then
echo "Usage: $0 <number_of_workers> <routing_strategy> [model_name] [endpoint_name]"
echo "Error: Must specify at least number of workers and routing strategy"
if [ $# -lt 2 ]; then
echo "Usage: $0 <number_of_workers> <log_dir_name> [model_name] [model_args] [chat_endpoint_name] [completions_endpoint_name]"
echo "Error: Must specify at least number of workers, log_dir_name"
echo "Optional: model_name (default: deepseek-ai/DeepSeek-R1-Distill-Llama-8B)"
echo "Optional: endpoint_name (default: dynamo.process.chat/completions)"
echo "Optional: model_args (quoted string with model arguments)"
echo "Optional: chat_endpoint_name (default: dynamo.process.chat/completions)"
echo "Optional: completions_endpoint_name (default: dynamo.process.completions)"
exit 1
fi
# Uncomment if using a local Hugging Face cache
# export HF_HUB_OFFLINE=1
# https://github.com/vllm-project/vllm/issues/10734#issuecomment-2507201353
# Fix for: torch.distributed.DistBackendError: File name too long
# export GLOO_SOCKET_IFNAME=lo
NUM_WORKERS=$1
ROUTING_STRATEGY=$2
LOG_DIR_NAME=$2
MODEL_NAME=${3:-"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"}
ENDPOINT_NAME=${4:-"dynamo.process.chat/completions"}
VALID_STRATEGIES=("prefix")
CUSTOM_MODEL_ARGS=$4
CHAT_ENDPOINT_NAME=${5:-"dynamo.process.chat/completions"}
COMPLETIONS_ENDPOINT_NAME=${6:-"dynamo.process.completions"}
SESSION_NAME="v"
WORKDIR="/workspace/examples/python_rs/llm/vllm"
INIT_CMD="cd $WORKDIR"
if [[ ! " ${VALID_STRATEGIES[@]} " =~ " ${ROUTING_STRATEGY} " ]]; then
echo "Error: Invalid routing strategy. Must be one of: ${VALID_STRATEGIES[*]}"
exit 1
# Default model args
DEFAULT_MODEL_ARGS="--model $MODEL_NAME \
--tokenizer $MODEL_NAME \
--enable-prefix-caching \
--block-size 64"
# Use custom model args if provided, otherwise use default
if [ -n "$CUSTOM_MODEL_ARGS" ]; then
MODEL_ARGS="$CUSTOM_MODEL_ARGS"
echo "Using custom model arguments"
else
MODEL_ARGS="$DEFAULT_MODEL_ARGS"
echo "Using default model arguments"
fi
# Create logs directory if it doesn't exist
LOGS_DIR="/logs/$LOG_DIR_NAME"
mkdir -p $LOGS_DIR
chmod -R 775 $LOGS_DIR
########################################################
# HTTP Server
########################################################
HTTP_CMD="DYN_LOG=DEBUG http"
HTTP_CMD="DYN_LOG=DEBUG http |& tee $LOGS_DIR/http.log"
tmux new-session -d -s "$SESSION_NAME-http"
tmux send-keys -t "$SESSION_NAME-http" "$INIT_CMD && $HTTP_CMD" C-m
########################################################
# LLMCTL
########################################################
LLMCTL_CMD="sleep 5 && llmctl http remove chat-model $MODEL_NAME && \
llmctl http add chat-model $MODEL_NAME $ENDPOINT_NAME && \
llmctl http list chat-model"
LLMCTL_CMD="sleep 5 && \
llmctl http remove chat $MODEL_NAME && \
llmctl http remove completions $MODEL_NAME && \
llmctl http add chat $MODEL_NAME $CHAT_ENDPOINT_NAME && \
llmctl http add completions $MODEL_NAME $COMPLETIONS_ENDPOINT_NAME && \
llmctl http list |& tee $LOGS_DIR/llmctl.log"
tmux new-session -d -s "$SESSION_NAME-llmctl"
tmux send-keys -t "$SESSION_NAME-llmctl" "$INIT_CMD && $LLMCTL_CMD" C-m
########################################################
# Processor
########################################################
# For now the processor gets the same args as the worker; they need to communicate over etcd
PROCESSOR_CMD="RUST_LOG=info python3 -m kv_router.processor \
--model $MODEL_NAME \
--tokenizer $MODEL_NAME \
--enable-prefix-caching \
--block-size 32 \
--max-model-len 16384 "
PROCESSOR_CMD="RUST_LOG=info python3 -m kv_router.processor $MODEL_ARGS |& tee $LOGS_DIR/processor.log"
tmux new-session -d -s "$SESSION_NAME-processor"
tmux send-keys -t "$SESSION_NAME-processor" "$INIT_CMD && $PROCESSOR_CMD" C-m
......@@ -72,10 +97,7 @@ tmux send-keys -t "$SESSION_NAME-processor" "$INIT_CMD && $PROCESSOR_CMD" C-m
# Router
########################################################
ROUTER_CMD="RUST_LOG=info python3 -m kv_router.router \
--model $MODEL_NAME \
--routing-strategy $ROUTING_STRATEGY \
--min-workers $NUM_WORKERS \
--block-size 32"
--min-workers $NUM_WORKERS |& tee $LOGS_DIR/router.log"
tmux new-session -d -s "$SESSION_NAME-router"
tmux send-keys -t "$SESSION_NAME-router" "$INIT_CMD && $ROUTER_CMD" C-m
......@@ -83,17 +105,12 @@ tmux send-keys -t "$SESSION_NAME-router" "$INIT_CMD && $ROUTER_CMD" C-m
########################################################
# Workers
########################################################
WORKER_CMD="RUST_LOG=info python3 -m kv_router.worker \
--model $MODEL_NAME \
--tokenizer $MODEL_NAME \
--enable-prefix-caching \
--block-size 64 \
--max-model-len 16384 "
WORKER_CMD="RUST_LOG=info python3 -m kv_router.worker $MODEL_ARGS"
for i in $(seq 1 $NUM_WORKERS); do
tmux new-session -d -s "$SESSION_NAME-$i"
tmux new-session -d -s "$SESSION_NAME-$i"
done
for i in $(seq 1 $NUM_WORKERS); do
tmux send-keys -t "$SESSION_NAME-$i" "$INIT_CMD && CUDA_VISIBLE_DEVICES=$((i-1)) $WORKER_CMD" C-m
done
tmux send-keys -t "$SESSION_NAME-$i" "$INIT_CMD && CUDA_VISIBLE_DEVICES=$((i-1)) $WORKER_CMD |& tee $LOGS_DIR/worker-$i.log" C-m
done
\ No newline at end of file
......@@ -228,6 +228,10 @@ async fn add_model(
endpoint_name
);
if model_name.starts_with('/') {
raise!("Model name '{}' cannot start with a slash", model_name);
}
let parts: Vec<&str> = endpoint_name.split('.').collect();
if parts.len() < 2 {
......
......@@ -90,6 +90,7 @@ impl KvMetricsPublisher {
})
}
#[allow(clippy::too_many_arguments)]
fn publish(
&self,
_py: Python,
......@@ -97,6 +98,9 @@ impl KvMetricsPublisher {
request_total_slots: u64,
kv_active_blocks: u64,
kv_total_blocks: u64,
num_requests_waiting: u64,
gpu_cache_usage_perc: f32,
gpu_prefix_cache_hit_rate: f32,
) -> PyResult<()> {
self.inner
.publish(
......@@ -105,6 +109,9 @@ impl KvMetricsPublisher {
request_total_slots,
kv_active_blocks,
kv_total_blocks,
num_requests_waiting,
gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate,
}
.into(),
)
......@@ -180,6 +187,10 @@ impl KvIndexer {
})
}
fn block_size(&self) -> usize {
self.inner.block_size()
}
fn find_matches_for_request<'p>(
&self,
py: Python<'p>,
......@@ -212,6 +223,12 @@ pub(crate) struct EndpointKvMetrics {
pub kv_active_blocks: u64,
#[pyo3(get, set)]
pub kv_total_blocks: u64,
#[pyo3(get, set)]
pub num_requests_waiting: u64,
#[pyo3(get, set)]
pub gpu_cache_usage_perc: f32,
#[pyo3(get, set)]
pub gpu_prefix_cache_hit_rate: f32,
}
#[pyclass]
......@@ -258,6 +275,9 @@ impl KvMetricsAggregator {
request_total_slots: x.data.request_total_slots,
kv_active_blocks: x.data.kv_active_blocks,
kv_total_blocks: x.data.kv_total_blocks,
num_requests_waiting: x.data.num_requests_waiting,
gpu_cache_usage_perc: x.data.gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate: x.data.gpu_prefix_cache_hit_rate,
})
.collect();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
......
......@@ -242,7 +242,11 @@ class KvMetricsPublisher:
def publish(self, request_active_slots: int,
request_total_slots: int,
kv_active_blocks: int,
kv_total_blocks: int) -> None:
kv_total_blocks: int,
num_requests_waiting: int,
gpu_cache_usage_perc: float,
gpu_prefix_cache_hit_rate: float
) -> None:
"""
Update the KV metrics being reported.
"""
......@@ -298,7 +302,7 @@ class KvIndexer:
...
def __init__(self, component: Component) -> None:
def __init__(self, component: Component, block_size: int) -> None:
"""
Create a `KvIndexer` object
"""
......@@ -309,6 +313,12 @@ class KvIndexer:
"""
...
def block_size(self) -> int:
"""
Return the block size of the KV Indexer.
"""
...
class AggregatedMetrics:
"""
A collection of metrics of the endpoints
......
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dynamo._core import AggregatedMetrics as AggregatedMetrics
from dynamo._core import DisaggregatedRouter as DisaggregatedRouter
from dynamo._core import HttpAsyncEngine as HttpAsyncEngine
from dynamo._core import HttpError as HttpError
......@@ -21,3 +22,4 @@ from dynamo._core import KvIndexer as KvIndexer
from dynamo._core import KvMetricsAggregator as KvMetricsAggregator
from dynamo._core import KvMetricsPublisher as KvMetricsPublisher
from dynamo._core import KvRouter as KvRouter
from dynamo._core import OverlapScores as OverlapScores
......@@ -193,6 +193,9 @@ async def test_metrics_aggregator(distributed_runtime):
"request_total_slots": 1024,
"kv_active_blocks": 523,
"kv_total_blocks": 777,
"num_requests_waiting": 10,
"gpu_cache_usage_perc": 0.5,
"gpu_prefix_cache_hit_rate": 0.75,
}
# need 'create_task' to put publisher task in the background
......@@ -222,5 +225,8 @@ async def metrics_publisher_task(kv_listener, expected_metrics):
expected_metrics["request_total_slots"],
expected_metrics["kv_active_blocks"],
expected_metrics["kv_total_blocks"],
expected_metrics["num_requests_waiting"],
expected_metrics["gpu_cache_usage_perc"],
expected_metrics["gpu_prefix_cache_hit_rate"],
)
await metrics_publisher.create_endpoint(kv_listener)
......@@ -16,7 +16,7 @@
use anyhow::Result;
use dynamo_runtime::{component::Component, component::Namespace, DistributedRuntime};
use futures::stream::StreamExt;
use std::{sync::Arc, time::Duration};
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use tracing;
......@@ -29,7 +29,8 @@ pub mod scoring;
use crate::kv_router::{
indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
scheduler::{Endpoint, KvScheduler, Service},
metrics_aggregator::collect_endpoints,
scheduler::KvScheduler,
scoring::ProcessedEndpoints,
};
......@@ -148,73 +149,3 @@ impl KvRouter {
Ok(worker_id)
}
}
async fn collect_endpoints(
nats_client: dynamo_runtime::transports::nats::Client,
service_name: String,
ep_tx: tokio::sync::mpsc::Sender<ProcessedEndpoints>,
cancel: CancellationToken,
) {
loop {
tokio::select! {
_ = cancel.cancelled() => {
tracing::debug!("cancellation token triggered");
break;
}
_ = tokio::time::sleep(Duration::from_secs(1)) => {
tracing::trace!("collecting endpoints for service: {}", service_name);
}
}
let values = match nats_client
.get_endpoints(&service_name, Duration::from_secs(1))
.await
{
Ok(v) => v,
Err(e) => {
tracing::warn!("Failed to retrieve endpoints for {}: {:?}", service_name, e);
continue;
}
};
tracing::debug!("values: {:?}", values);
let services: Vec<Service> = values
.into_iter()
.filter(|v| !v.is_empty())
.filter_map(|v| match serde_json::from_slice::<Service>(&v) {
Ok(service) => Some(service),
Err(e) => {
tracing::warn!("For value: {:?} \nFailed to parse service: {:?}", v, e);
None
}
})
.collect();
tracing::debug!("services: {:?}", services);
let endpoints: Vec<Endpoint> = services
.into_iter()
.flat_map(|s| s.endpoints)
.filter(|s| s.data.is_some())
.map(|s| Endpoint {
name: s.name,
subject: s.subject,
data: s.data.unwrap(),
})
.collect();
tracing::debug!("endpoints: {:?}", endpoints);
tracing::trace!(
"found {} endpoints for service: {}",
endpoints.len(),
service_name
);
let processed = ProcessedEndpoints::new(endpoints);
// process endpoints into a ProcessedEndpoints snapshot
if ep_tx.send(processed).await.is_err() {
tracing::trace!("failed to send processed endpoints; shutting down");
break;
}
}
}
......@@ -588,6 +588,10 @@ impl KvIndexer {
}
}
pub fn block_size(&self) -> usize {
self.kv_block_size
}
pub fn new(token: CancellationToken, kv_block_size: usize) -> Self {
Self::new_with_frequency(token, None, kv_block_size)
}
......@@ -775,6 +779,10 @@ impl KvIndexerSharded {
}
}
pub fn block_size(&self) -> usize {
self.kv_block_size
}
pub fn new(token: CancellationToken, num_shards: usize, kv_block_size: usize) -> Self {
Self::new_with_frequency(token, num_shards, None, kv_block_size)
}
......
......@@ -80,70 +80,71 @@ impl KvMetricsAggregator {
}
}
async fn collect_endpoints(
pub async fn collect_endpoints(
nats_client: dynamo_runtime::transports::nats::Client,
service_name: String,
ep_tx: tokio::sync::mpsc::Sender<ProcessedEndpoints>,
cancel: CancellationToken,
) {
let backoff_delay = Duration::from_millis(100);
loop {
tokio::select! {
_ = cancel.cancelled() => {
tracing::debug!("cancellation token triggered");
break;
}
_ = tokio::time::sleep(Duration::from_secs(1)) => {
_ = tokio::time::sleep(backoff_delay) => {
tracing::trace!("collecting endpoints for service: {}", service_name);
}
}
let values = match nats_client
.get_endpoints(&service_name, Duration::from_secs(1))
.await
{
Ok(v) => v,
Err(e) => {
tracing::warn!("Failed to retrieve endpoints for {}: {:?}", service_name, e);
continue;
}
};
let values = match nats_client
.get_endpoints(&service_name, Duration::from_millis(300))
.await
{
Ok(v) => v,
Err(e) => {
tracing::warn!("Failed to retrieve endpoints for {}: {:?}", service_name, e);
continue;
}
};
tracing::debug!("values: {:?}", values);
let services: Vec<Service> = values
.into_iter()
.filter(|v| !v.is_empty())
.filter_map(|v| match serde_json::from_slice::<Service>(&v) {
Ok(service) => Some(service),
Err(e) => {
tracing::warn!("For value: {:?} \nFailed to parse service: {:?}", v, e);
None
}
})
.collect();
tracing::debug!("services: {:?}", services);
tracing::debug!("values: {:?}", values);
let services: Vec<Service> = values
.into_iter()
.filter(|v| !v.is_empty())
.filter_map(|v| match serde_json::from_slice::<Service>(&v) {
Ok(service) => Some(service),
Err(e) => {
tracing::warn!("For value: {:?} \nFailed to parse service: {:?}", v, e);
None
}
})
.collect();
tracing::debug!("services: {:?}", services);
let endpoints: Vec<Endpoint> = services
.into_iter()
.flat_map(|s| s.endpoints)
.filter(|s| s.data.is_some())
.map(|s| Endpoint {
name: s.name,
subject: s.subject,
data: s.data.unwrap(),
})
.collect();
tracing::debug!("endpoints: {:?}", endpoints);
let endpoints: Vec<Endpoint> = services
.into_iter()
.flat_map(|s| s.endpoints)
.filter(|s| s.data.is_some())
.map(|s| Endpoint {
name: s.name,
subject: s.subject,
data: s.data.unwrap(),
})
.collect();
tracing::debug!("endpoints: {:?}", endpoints);
tracing::trace!(
"found {} endpoints for service: {}",
endpoints.len(),
service_name
);
tracing::trace!(
"found {} endpoints for service: {}",
endpoints.len(),
service_name
);
let processed = ProcessedEndpoints::new(endpoints);
if ep_tx.send(processed).await.is_err() {
tracing::trace!("failed to send processed endpoints; shutting down");
break;
let processed = ProcessedEndpoints::new(endpoints);
if ep_tx.send(processed).await.is_err() {
tracing::trace!("failed to send processed endpoints; shutting down");
break;
}
}
}
}
}
......@@ -21,6 +21,12 @@ pub struct ForwardPassMetrics {
pub request_total_slots: u64,
pub kv_active_blocks: u64,
pub kv_total_blocks: u64,
// count of waiting requests; non-negative integer, unbounded
pub num_requests_waiting: u64,
// usage as a fraction from 0.0 to 1.0
pub gpu_cache_usage_perc: f32,
// hit rate as a fraction from 0.0 to 1.0
pub gpu_prefix_cache_hit_rate: f32,
}
/// A [`BlockHash`] is a hash computed from the tokens_ids, extra_token_ids and the optional
......