Unverified Commit 84e71e27 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: predictive active blocks for routing without load metrics (#1731)


Signed-off-by: default avatarYan Ru Pei <yanrpei@gmail.com>
Co-authored-by: default avatarAlec <35311602+alec-flowers@users.noreply.github.com>
parent ffccc722
...@@ -83,7 +83,7 @@ use serde::{Deserialize, Serialize}; ...@@ -83,7 +83,7 @@ use serde::{Deserialize, Serialize};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::time::Duration as StdDuration; use std::time::Duration as StdDuration;
use dynamo_llm::kv_router::protocols::ForwardPassMetrics; use dynamo_llm::kv_router::protocols::{ForwardPassMetrics, LoadMetrics};
use dynamo_llm::kv_router::scheduler::Endpoint; use dynamo_llm::kv_router::scheduler::Endpoint;
use dynamo_llm::kv_router::scoring::ProcessedEndpoints; use dynamo_llm::kv_router::scoring::ProcessedEndpoints;
...@@ -449,7 +449,10 @@ impl PrometheusMetrics { ...@@ -449,7 +449,10 @@ impl PrometheusMetrics {
// Update per-worker metrics // Update per-worker metrics
for (worker_id, endpoint) in processed.endpoints.iter() { for (worker_id, endpoint) in processed.endpoints.iter() {
let worker_id = worker_id.to_string(); let worker_id = worker_id.to_string();
let metrics = endpoint.data.clone(); let load_metrics = endpoint.data.clone();
let LoadMetrics::EngineLoadMetrics(metrics) = load_metrics else {
panic!("Can only update with ForwardPassMetrics");
};
self.set_worker_gauge( self.set_worker_gauge(
&self.kv_blocks_active, &self.kv_blocks_active,
...@@ -602,7 +605,7 @@ pub fn postprocess_metrics( ...@@ -602,7 +605,7 @@ pub fn postprocess_metrics(
e.id().ok().map(|id| Endpoint { e.id().ok().map(|id| Endpoint {
name: format!("worker-{id}"), name: format!("worker-{id}"),
subject: e.subject.clone(), subject: e.subject.clone(),
data: m.clone(), data: LoadMetrics::EngineLoadMetrics(m.clone()),
}) })
}) })
.collect(); .collect();
......
...@@ -66,7 +66,7 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -66,7 +66,7 @@ async fn app(runtime: Runtime) -> Result<()> {
let selector = Box::new(CustomWorkerSelector::default()); let selector = Box::new(CustomWorkerSelector::default());
let router = KvRouter::new(component.clone(), args.block_size, Some(selector)).await?; let router = KvRouter::new(component.clone(), args.block_size, Some(selector), true).await?;
let router = Ingress::for_engine(Arc::new(router))?; let router = Ingress::for_engine(Arc::new(router))?;
component component
......
...@@ -3392,14 +3392,8 @@ index cafd8150b..6a5e45b4e 100644 ...@@ -3392,14 +3392,8 @@ index cafd8150b..6a5e45b4e 100644
+ num_requests_waiting: int + num_requests_waiting: int
+ gpu_cache_usage_perc: float + gpu_cache_usage_perc: float
+ gpu_prefix_cache_hit_rate: float + gpu_prefix_cache_hit_rate: float
+ spec_decode_draft_acceptance_rate: Optional[float] = None
+ spec_decode_system_efficiency: Optional[float] = None
+ spec_decode_draft_tokens: Optional[int] = None
+ spec_decode_emitted_tokens: Optional[int] = None
+ spec_decode_accepted_tokens: Optional[int] = None
+ spec_decode_num_spec_tokens: Optional[int] = None
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index f058b1329..2fdb5b8bf 100644 index f058b1329..fd5610a3c 100644
--- a/vllm/engine/multiprocessing/client.py --- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py
@@ -1,4 +1,17 @@ @@ -1,4 +1,17 @@
...@@ -3460,16 +3454,25 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3460,16 +3454,25 @@ index f058b1329..2fdb5b8bf 100644
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
# yapf: enable # yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT from vllm.envs import VLLM_RPC_TIMEOUT
@@ -48,6 +66,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest @@ -48,6 +66,17 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import Device, deprecate_kwargs from vllm.utils import Device, deprecate_kwargs
+from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest, RemotePrefillRequestCallback +from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest, RemotePrefillRequestCallback
+from vllm.distributed.device_communicators.nixl import NixlMetadata +from vllm.distributed.device_communicators.nixl import NixlMetadata
+
+# Import ForwardPassMetrics and related classes from dynamo
+try:
+ from dynamo.llm import ForwardPassMetrics, WorkerStats, KvStats
+except ImportError:
+ # Fallback if dynamo imports are not available
+ ForwardPassMetrics = None
+ WorkerStats = None
+ KvStats = None
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -93,6 +113,7 @@ class MQLLMEngineClient(EngineClient): @@ -93,6 +122,7 @@ class MQLLMEngineClient(EngineClient):
self._errored_with: Optional[BaseException] = None self._errored_with: Optional[BaseException] = None
# Get the configs. # Get the configs.
...@@ -3477,7 +3480,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3477,7 +3480,7 @@ index f058b1329..2fdb5b8bf 100644
self.model_config = engine_config.model_config self.model_config = engine_config.model_config
self.decoding_config = engine_config.decoding_config self.decoding_config = engine_config.decoding_config
@@ -117,6 +138,10 @@ class MQLLMEngineClient(EngineClient): @@ -117,6 +147,10 @@ class MQLLMEngineClient(EngineClient):
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL) self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
...@@ -3488,7 +3491,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3488,7 +3491,7 @@ index f058b1329..2fdb5b8bf 100644
# IPC path for the data socket. # IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
@@ -131,8 +156,27 @@ class MQLLMEngineClient(EngineClient): @@ -131,8 +165,27 @@ class MQLLMEngineClient(EngineClient):
# Loop to check health of the LLMEngine periodically. # Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready. # Started after the MQLLMEngine is ready.
self.health_loop: Optional[asyncio.Task] = None self.health_loop: Optional[asyncio.Task] = None
...@@ -3516,7 +3519,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3516,7 +3519,7 @@ index f058b1329..2fdb5b8bf 100644
@staticmethod @staticmethod
def is_unsupported_config(vllm_config: VllmConfig): def is_unsupported_config(vllm_config: VllmConfig):
# Pipeline parallel not yet supported # Pipeline parallel not yet supported
@@ -182,6 +226,61 @@ class MQLLMEngineClient(EngineClient): @@ -182,6 +235,76 @@ class MQLLMEngineClient(EngineClient):
except Exception as e: except Exception as e:
self._set_errored(e) self._set_errored(e)
...@@ -3553,13 +3556,28 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3553,13 +3556,28 @@ index f058b1329..2fdb5b8bf 100644
+ if self.metrics_publisher is not None and isinstance( + if self.metrics_publisher is not None and isinstance(
+ metrics, KvMetrics + metrics, KvMetrics
+ ): + ):
+ self.metrics_publisher.publish(metrics.request_active_slots, + # Construct structured metrics objects
+ metrics.request_total_slots, + worker_stats = WorkerStats(
+ metrics.kv_active_blocks, + request_active_slots=metrics.request_active_slots,
+ metrics.kv_total_blocks, + request_total_slots=metrics.request_total_slots,
+ metrics.num_requests_waiting, + num_requests_waiting=metrics.num_requests_waiting,
+ metrics.gpu_cache_usage_perc, + data_parallel_rank=None
+ metrics.gpu_prefix_cache_hit_rate) + )
+
+ kv_stats = KvStats(
+ kv_active_blocks=metrics.kv_active_blocks,
+ kv_total_blocks=metrics.kv_total_blocks,
+ gpu_cache_usage_perc=metrics.gpu_cache_usage_perc,
+ gpu_prefix_cache_hit_rate=metrics.gpu_prefix_cache_hit_rate
+ )
+
+ forward_pass_metrics = ForwardPassMetrics(
+ worker_stats=worker_stats,
+ kv_stats=kv_stats,
+ spec_decode_stats=None
+ )
+
+ self.metrics_publisher.publish(forward_pass_metrics)
+ logger.debug("Metrics successful.") + logger.debug("Metrics successful.")
+ +
+ # TODO: Investigate sending whole stats object + # TODO: Investigate sending whole stats object
...@@ -3578,7 +3596,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3578,7 +3596,7 @@ index f058b1329..2fdb5b8bf 100644
async def run_output_handler_loop(self): async def run_output_handler_loop(self):
"""Get RequestOutputs from Engine and stream to Request Queues""" """Get RequestOutputs from Engine and stream to Request Queues"""
@@ -250,7 +349,7 @@ class MQLLMEngineClient(EngineClient): @@ -250,7 +373,7 @@ class MQLLMEngineClient(EngineClient):
# Put each output into the appropriate queue. # Put each output into the appropriate queue.
elif isinstance( elif isinstance(
request_outputs, request_outputs,
...@@ -3587,7 +3605,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3587,7 +3605,7 @@ index f058b1329..2fdb5b8bf 100644
self._add_output(request_outputs) self._add_output(request_outputs)
else: else:
for request_output in request_outputs: for request_output in request_outputs:
@@ -261,7 +360,7 @@ class MQLLMEngineClient(EngineClient): @@ -261,7 +384,7 @@ class MQLLMEngineClient(EngineClient):
def _add_output(self, request_output: Union[RequestOutput, def _add_output(self, request_output: Union[RequestOutput,
RPCAdapterLoadedResponse, RPCAdapterLoadedResponse,
...@@ -3596,7 +3614,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3596,7 +3614,7 @@ index f058b1329..2fdb5b8bf 100644
queue = self.output_queues.get(request_output.request_id) queue = self.output_queues.get(request_output.request_id)
if queue is not None: if queue is not None:
queue.put_nowait(request_output) queue.put_nowait(request_output)
@@ -283,12 +382,25 @@ class MQLLMEngineClient(EngineClient): @@ -283,12 +406,25 @@ class MQLLMEngineClient(EngineClient):
# Wait until server is ready. # Wait until server is ready.
response = await self._wait_for_server_rpc(socket) response = await self._wait_for_server_rpc(socket)
...@@ -3622,7 +3640,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3622,7 +3640,7 @@ index f058b1329..2fdb5b8bf 100644
def close(self): def close(self):
"""Destroy the ZeroMQ Context.""" """Destroy the ZeroMQ Context."""
@@ -298,6 +410,8 @@ class MQLLMEngineClient(EngineClient): @@ -298,6 +434,8 @@ class MQLLMEngineClient(EngineClient):
# Cancel background tasks. # Cancel background tasks.
if self.health_loop is not None: if self.health_loop is not None:
self.health_loop.cancel() self.health_loop.cancel()
...@@ -3631,7 +3649,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3631,7 +3649,7 @@ index f058b1329..2fdb5b8bf 100644
if self.output_loop is not None: if self.output_loop is not None:
self.output_loop.cancel() self.output_loop.cancel()
@@ -420,6 +534,9 @@ class MQLLMEngineClient(EngineClient): @@ -420,6 +558,9 @@ class MQLLMEngineClient(EngineClient):
""" """
if self._errored_with is not None: if self._errored_with is not None:
raise self._errored_with raise self._errored_with
...@@ -3641,7 +3659,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3641,7 +3659,7 @@ index f058b1329..2fdb5b8bf 100644
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
@@ -478,6 +595,7 @@ class MQLLMEngineClient(EngineClient): @@ -478,6 +619,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
...@@ -3649,7 +3667,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3649,7 +3667,7 @@ index f058b1329..2fdb5b8bf 100644
*, *,
inputs: Optional[PromptType] = None # DEPRECATED inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
@@ -507,7 +625,8 @@ class MQLLMEngineClient(EngineClient): @@ -507,7 +649,8 @@ class MQLLMEngineClient(EngineClient):
return self._process_request(prompt, sampling_params, request_id, return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers, lora_request, trace_headers,
...@@ -3659,7 +3677,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3659,7 +3677,7 @@ index f058b1329..2fdb5b8bf 100644
@overload @overload
def encode( def encode(
@@ -591,6 +710,7 @@ class MQLLMEngineClient(EngineClient): @@ -591,6 +734,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
...@@ -3667,7 +3685,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3667,7 +3685,7 @@ index f058b1329..2fdb5b8bf 100644
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
PoolingRequestOutput, None]]: PoolingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses.""" """Send an RPCGenerateRequest to the RPCServer and stream responses."""
@@ -636,6 +756,12 @@ class MQLLMEngineClient(EngineClient): @@ -636,6 +780,12 @@ class MQLLMEngineClient(EngineClient):
else: else:
lp_bytes = None lp_bytes = None
...@@ -3680,7 +3698,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3680,7 +3698,7 @@ index f058b1329..2fdb5b8bf 100644
request_bytes = pickle.dumps( request_bytes = pickle.dumps(
RPCProcessRequest( RPCProcessRequest(
prompt=prompt, prompt=prompt,
@@ -645,11 +771,11 @@ class MQLLMEngineClient(EngineClient): @@ -645,11 +795,11 @@ class MQLLMEngineClient(EngineClient):
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
priority=priority, priority=priority,
...@@ -3694,7 +3712,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3694,7 +3712,7 @@ index f058b1329..2fdb5b8bf 100644
await self.input_socket.send_multipart(parts, copy=False) await self.input_socket.send_multipart(parts, copy=False)
# 4) Stream the RequestOutputs from the output queue. Note # 4) Stream the RequestOutputs from the output queue. Note
@@ -740,3 +866,22 @@ class MQLLMEngineClient(EngineClient): @@ -740,3 +890,22 @@ class MQLLMEngineClient(EngineClient):
# Raise on error, otherwise happily return None # Raise on error, otherwise happily return None
if isinstance(request_output, BaseException): if isinstance(request_output, BaseException):
raise request_output raise request_output
......
...@@ -15,6 +15,8 @@ See the License for the specific language governing permissions and ...@@ -15,6 +15,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
--> -->
>[!NOTE]
>This information is temporary and will change soon.
# KV Cache Routing # KV Cache Routing
This documentation explains how Key-Value (KV) cache routing works in Dynamo, providing optimized inference for large language models by intelligently directing requests to workers with the most relevant cached data while simultaneously load balancing based on utilization metrics sent by the workers. This documentation explains how Key-Value (KV) cache routing works in Dynamo, providing optimized inference for large language models by intelligently directing requests to workers with the most relevant cached data while simultaneously load balancing based on utilization metrics sent by the workers.
......
...@@ -8,7 +8,7 @@ It supports these engines: mistralrs, llamacpp, sglang, vllm, and tensorrt-llm. ...@@ -8,7 +8,7 @@ It supports these engines: mistralrs, llamacpp, sglang, vllm, and tensorrt-llm.
Usage: Usage:
``` ```
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=2.0] [--kv-gpu-cache-usage-weight=1.0] [--kv-waiting-requests-weight=1.0] [--verbosity (-v|-vv)] dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=1.0] [--router-temperature=0.5] [--verbosity (-v|-vv)]
``` ```
Example: `dynamo run Qwen/Qwen3-0.6B` Example: `dynamo run Qwen/Qwen3-0.6B`
...@@ -201,6 +201,8 @@ The only difference from the distributed system above is `--router-mode kv`. The ...@@ -201,6 +201,8 @@ The only difference from the distributed system above is `--router-mode kv`. The
For performance testing, compare a typical workload with `--router-mode random|round-robin` to see if it can benefit from KV-aware routing. For performance testing, compare a typical workload with `--router-mode random|round-robin` to see if it can benefit from KV-aware routing.
The argument `--kv-overlap-score-weight` sets the amount weighting on overlaps with prefix caches, which directly contributes to the prefill cost, so a large weight is expected to yield a better TTFT (at the expense of worse ITL). When this is set 0, we do not consider the prefix caches at all (falling back to pure load balancing behavior on the active blocks), in which case we do not require the backend engines to emit any KV events. The argument `--router-temperature` sets the temperature when randomly selecting the workers to route to via softmax sampling on the router cost logits, setting it to 0 recovers the deterministic behavior where the min logit is picked.
## Full usage details ## Full usage details
`dynamo run` executes `dynamo-run`. `dynamo-run` is also an example of what can be built in Rust with the `dynamo-llm` and `dynamo-runtime` crates. The following guide shows how to build from source with all the features. `dynamo run` executes `dynamo-run`. `dynamo-run` is also an example of what can be built in Rust with the `dynamo-llm` and `dynamo-runtime` crates. The following guide shows how to build from source with all the features.
......
...@@ -15,6 +15,9 @@ See the License for the specific language governing permissions and ...@@ -15,6 +15,9 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
--> -->
>[!NOTE]
>This information is temporary and will change soon.
# KV Router Performance Tuning # KV Router Performance Tuning
## Overview ## Overview
......
...@@ -110,20 +110,23 @@ pub struct Flags { ...@@ -110,20 +110,23 @@ pub struct Flags {
#[arg(long, default_value = "round-robin")] #[arg(long, default_value = "round-robin")]
pub router_mode: RouterMode, pub router_mode: RouterMode,
/// Maximum number of batched tokens for KV routing
/// Needed for informing the KV router
/// TODO: derive from vllm args
/// NOTE: this is not actually used for now
#[arg(long, default_value = "8192")]
pub max_num_batched_tokens: Option<u32>,
/// KV Router: Weight for overlap score in worker selection. /// KV Router: Weight for overlap score in worker selection.
/// Higher values prioritize KV cache reuse. Default: 2.0 /// Higher values prioritize KV cache reuse. Default: 2.0
#[arg(long)] #[arg(long)]
pub kv_overlap_score_weight: Option<f64>, pub kv_overlap_score_weight: Option<f64>,
/// KV Router: Weight for GPU cache usage in worker selection. /// KV Router: Temperature for worker sampling via softmax.
/// Higher values avoid workers with nearly full KV caches. Default: 1.0 /// Higher values promote more randomness, and 0 fallbacks to deterministic.
#[arg(long)] /// Default: 0.5
pub kv_gpu_cache_usage_weight: Option<f64>,
/// KV Router: Weight for waiting requests in worker selection.
/// Higher values avoid workers with queued requests. Default: 1.0
#[arg(long)] #[arg(long)]
pub kv_waiting_requests_weight: Option<f64>, pub router_temperature: Option<f64>,
/// Max model context length. Reduce this if you don't have enough VRAM for the full model /// Max model context length. Reduce this if you don't have enough VRAM for the full model
/// context length (e.g. Llama 4). /// context length (e.g. Llama 4).
...@@ -211,8 +214,8 @@ impl Flags { ...@@ -211,8 +214,8 @@ impl Flags {
self.router_mode.into(), self.router_mode.into(),
KvRouterConfig::new( KvRouterConfig::new(
self.kv_overlap_score_weight, self.kv_overlap_score_weight,
self.kv_gpu_cache_usage_weight, self.router_temperature,
self.kv_waiting_requests_weight, self.max_num_batched_tokens,
), ),
) )
} }
......
...@@ -26,7 +26,14 @@ from vllm.entrypoints.openai.api_server import ( ...@@ -26,7 +26,14 @@ from vllm.entrypoints.openai.api_server import (
) )
from vllm.inputs import TokensPrompt from vllm.inputs import TokensPrompt
from dynamo.llm import ModelType, WorkerMetricsPublisher, register_llm from dynamo.llm import (
ForwardPassMetrics,
KvStats,
ModelType,
WorkerMetricsPublisher,
WorkerStats,
register_llm,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -70,15 +77,29 @@ class RequestHandler: ...@@ -70,15 +77,29 @@ class RequestHandler:
self.engine_client.set_metrics_publisher(self.metrics_publisher) self.engine_client.set_metrics_publisher(self.metrics_publisher)
# Initially send dummy metrics to kick start, # Initially send dummy metrics to kick start,
# vLLM will not update stat until forward pass is triggered # vLLM will not update stat until forward pass is triggered
self.metrics_publisher.publish(
0, # request_active_slots # Create the structured metrics objects
1024, # request_total_slots worker_stats = WorkerStats(
0, # kv_active_blocks request_active_slots=0,
1024, # kv_total_blocks request_total_slots=1024,
0, # num_requests_waiting num_requests_waiting=0,
0.0, # gpu_cache_usage_perc data_parallel_rank=None,
0.0, # gpu_prefix_cache_hit_rate )
kv_stats = KvStats(
kv_active_blocks=0,
kv_total_blocks=1024,
gpu_cache_usage_perc=0.0,
gpu_prefix_cache_hit_rate=0.0,
) )
metrics = ForwardPassMetrics(
worker_stats=worker_stats, kv_stats=kv_stats, spec_decode_stats=None
)
# Publish the metrics as a single object
self.metrics_publisher.publish(metrics)
task = asyncio.create_task(self.create_metrics_publisher_endpoint()) task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback( task.add_done_callback(
lambda _: logging.debug("metrics publisher endpoint created") lambda _: logging.debug("metrics publisher endpoint created")
......
...@@ -72,7 +72,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -72,7 +72,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Client>()?; m.add_class::<Client>()?;
m.add_class::<EtcdClient>()?; m.add_class::<EtcdClient>()?;
m.add_class::<AsyncResponseStream>()?; m.add_class::<AsyncResponseStream>()?;
m.add_class::<llm::kv::KvRouter>()?;
m.add_class::<llm::disagg_router::DisaggregatedRouter>()?; m.add_class::<llm::disagg_router::DisaggregatedRouter>()?;
m.add_class::<llm::kv::WorkerMetricsPublisher>()?; m.add_class::<llm::kv::WorkerMetricsPublisher>()?;
m.add_class::<llm::model_card::ModelDeploymentCard>()?; m.add_class::<llm::model_card::ModelDeploymentCard>()?;
......
...@@ -29,51 +29,6 @@ use tracing; ...@@ -29,51 +29,6 @@ use tracing;
use llm_rs::kv_router::protocols::*; use llm_rs::kv_router::protocols::*;
use llm_rs::kv_router::publisher::{create_stored_blocks, KvEventSourceConfig}; use llm_rs::kv_router::publisher::{create_stored_blocks, KvEventSourceConfig};
#[pyclass]
pub(crate) struct KvRouter {
inner: Arc<llm_rs::kv_router::KvRouter>,
}
#[pymethods]
impl KvRouter {
#[new]
fn new(component: Component, kv_block_size: usize) -> PyResult<Self> {
if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
};
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let inner = llm_rs::kv_router::KvRouter::new(
component.inner.clone(),
kv_block_size as u32,
None,
)
.await
.map_err(to_pyerr)?;
Ok(Self {
inner: Arc::new(inner),
})
})
}
fn schedule<'p>(
&self,
py: Python<'p>,
token_ids: Vec<u32>,
lora_id: u64,
) -> PyResult<Bound<'p, PyAny>> {
let router = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let worker_id = router
.schedule(&token_ids, lora_id)
.await
.map_err(to_pyerr)?;
Ok(worker_id)
})
}
}
#[pyfunction] #[pyfunction]
pub fn compute_block_hash_for_seq_py(tokens: Vec<u32>, kv_block_size: usize) -> PyResult<Vec<u64>> { pub fn compute_block_hash_for_seq_py(tokens: Vec<u32>, kv_block_size: usize) -> PyResult<Vec<u64>> {
if kv_block_size == 0 { if kv_block_size == 0 {
...@@ -617,25 +572,34 @@ impl KvMetricsAggregator { ...@@ -617,25 +572,34 @@ impl KvMetricsAggregator {
fn get_metrics<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> { fn get_metrics<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
// TODO: update EndpointKvMetrics to match the new ForwardPassMetrics struct // TODO: update EndpointKvMetrics to match the new ForwardPassMetrics struct
let endpoints = self.inner.get_endpoints(); let endpoints = self.inner.get_endpoints();
let load_avg = endpoints.load_avg;
let load_std = endpoints.load_std;
let endpoint_kv_metrics = endpoints let endpoint_kv_metrics = endpoints
.endpoints .endpoints
.iter() .into_iter()
.map(|(worker_id, endpoint)| EndpointKvMetrics { .map(|(worker_id, endpoint)| {
worker_id: *worker_id, let metrics = endpoint.data;
request_active_slots: endpoint.data.worker_stats.request_active_slots, let LoadMetrics::EngineLoadMetrics(fwd_pass_metrics) = metrics else {
request_total_slots: endpoint.data.worker_stats.request_total_slots, panic!("Endpoints do not contain forward pass metrics.");
kv_active_blocks: endpoint.data.kv_stats.kv_active_blocks, };
kv_total_blocks: endpoint.data.kv_stats.kv_total_blocks, EndpointKvMetrics {
num_requests_waiting: endpoint.data.worker_stats.num_requests_waiting, worker_id,
gpu_cache_usage_perc: endpoint.data.kv_stats.gpu_cache_usage_perc, request_active_slots: fwd_pass_metrics.worker_stats.request_active_slots,
gpu_prefix_cache_hit_rate: endpoint.data.kv_stats.gpu_prefix_cache_hit_rate, request_total_slots: fwd_pass_metrics.worker_stats.request_total_slots,
kv_active_blocks: fwd_pass_metrics.kv_stats.kv_active_blocks,
kv_total_blocks: fwd_pass_metrics.kv_stats.kv_total_blocks,
num_requests_waiting: fwd_pass_metrics.worker_stats.num_requests_waiting,
gpu_cache_usage_perc: fwd_pass_metrics.kv_stats.gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate: fwd_pass_metrics.kv_stats.gpu_prefix_cache_hit_rate,
}
}) })
.collect(); .collect();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
Ok(AggregatedMetrics { Ok(AggregatedMetrics {
endpoints: endpoint_kv_metrics, endpoints: endpoint_kv_metrics,
load_avg: endpoints.load_avg, load_avg,
load_std: endpoints.load_std, load_std,
}) })
}) })
} }
......
...@@ -272,25 +272,6 @@ class Client: ...@@ -272,25 +272,6 @@ class Client:
""" """
... ...
class KvRouter:
"""
A router will determine which worker should handle a given request.
"""
...
def __init__(self, drt: DistributedRuntime, component: Component) -> None:
"""
Create a `KvRouter` object that is associated with the `component`
"""
def schedule(self, token_ids: List[int], lora_id: int) -> int:
"""
Return the worker id that should handle the given token ids,
exception will be raised if there is no worker available.
"""
...
class DisaggregatedRouter: class DisaggregatedRouter:
""" """
A router that determines whether to perform prefill locally or remotely based on A router that determines whether to perform prefill locally or remotely based on
......
...@@ -32,7 +32,6 @@ from dynamo._core import KvEventPublisher as KvEventPublisher ...@@ -32,7 +32,6 @@ from dynamo._core import KvEventPublisher as KvEventPublisher
from dynamo._core import KvIndexer as KvIndexer from dynamo._core import KvIndexer as KvIndexer
from dynamo._core import KvMetricsAggregator as KvMetricsAggregator from dynamo._core import KvMetricsAggregator as KvMetricsAggregator
from dynamo._core import KvRecorder as KvRecorder from dynamo._core import KvRecorder as KvRecorder
from dynamo._core import KvRouter as KvRouter
from dynamo._core import KvStats as KvStats from dynamo._core import KvStats as KvStats
from dynamo._core import ModelType as ModelType from dynamo._core import ModelType as ModelType
from dynamo._core import OverlapScores as OverlapScores from dynamo._core import OverlapScores as OverlapScores
......
...@@ -212,8 +212,20 @@ impl ModelManager { ...@@ -212,8 +212,20 @@ impl ModelManager {
kv_cache_block_size: u32, kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
) -> anyhow::Result<Arc<KvRouter>> { ) -> anyhow::Result<Arc<KvRouter>> {
// Determine if we should use KV events based on overlap score weight
let use_kv_events = kv_router_config
.as_ref()
.map(|config| config.overlap_score_weight > 0.0)
.unwrap_or(false);
let selector = Box::new(DefaultWorkerSelector::new(kv_router_config)); let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
let chooser = KvRouter::new(component.clone(), kv_cache_block_size, Some(selector)).await?; let chooser = KvRouter::new(
component.clone(),
kv_cache_block_size,
Some(selector),
use_kv_events,
)
.await?;
let new_kv_chooser = Arc::new(chooser); let new_kv_chooser = Arc::new(chooser);
self.kv_choosers self.kv_choosers
.lock() .lock()
......
...@@ -23,11 +23,12 @@ pub mod publisher; ...@@ -23,11 +23,12 @@ pub mod publisher;
pub mod recorder; pub mod recorder;
pub mod scheduler; pub mod scheduler;
pub mod scoring; pub mod scoring;
pub mod sequence;
use crate::{ use crate::{
kv_router::{ kv_router::{
indexer::{KvIndexer, KvIndexerInterface, RouterEvent}, indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
metrics_aggregator::KvMetricsAggregator, metrics_aggregator::EndpointCollector,
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult}, protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest}, scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
scoring::ProcessedEndpoints, scoring::ProcessedEndpoints,
...@@ -58,25 +59,20 @@ pub trait WorkerSelector { ...@@ -58,25 +59,20 @@ pub trait WorkerSelector {
/// KV Router configuration parameters /// KV Router configuration parameters
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct KvRouterConfig { pub struct KvRouterConfig {
/// Weight for overlap score in worker selection.
/// Higher values prioritize KV cache reuse. Default: 2.0
pub overlap_score_weight: f64, pub overlap_score_weight: f64,
/// Weight for GPU cache usage in worker selection. pub router_temperature: f64,
/// Higher values avoid workers with nearly full KV caches. Default: 1.0
pub gpu_cache_usage_weight: f64,
/// Weight for waiting requests in worker selection. // note: this is not actually used for now
/// Higher values avoid workers with queued requests. Default: 1.0 pub max_num_batched_tokens: u32,
pub waiting_requests_weight: f64,
} }
impl Default for KvRouterConfig { impl Default for KvRouterConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
overlap_score_weight: 1.0, overlap_score_weight: 1.0,
gpu_cache_usage_weight: 1.0, router_temperature: 0.5,
waiting_requests_weight: 1.0, max_num_batched_tokens: 8192,
} }
} }
} }
...@@ -86,16 +82,15 @@ impl KvRouterConfig { ...@@ -86,16 +82,15 @@ impl KvRouterConfig {
/// If a weight is None, the default value will be used. /// If a weight is None, the default value will be used.
pub fn new( pub fn new(
overlap_score_weight: Option<f64>, overlap_score_weight: Option<f64>,
gpu_cache_usage_weight: Option<f64>, temperature: Option<f64>,
waiting_requests_weight: Option<f64>, max_num_batched_tokens: Option<u32>,
) -> Self { ) -> Self {
let default = Self::default(); let default = Self::default();
Self { Self {
overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight), overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
gpu_cache_usage_weight: gpu_cache_usage_weight router_temperature: temperature.unwrap_or(default.router_temperature),
.unwrap_or(default.gpu_cache_usage_weight), max_num_batched_tokens: max_num_batched_tokens
waiting_requests_weight: waiting_requests_weight .unwrap_or(default.max_num_batched_tokens),
.unwrap_or(default.waiting_requests_weight),
} }
} }
} }
...@@ -103,7 +98,7 @@ impl KvRouterConfig { ...@@ -103,7 +98,7 @@ impl KvRouterConfig {
/// A KvRouter only decides which worker you should use. It doesn't send you there. /// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route. /// TODO: Rename this to indicate it only selects a worker, it does not route.
pub struct KvRouter { pub struct KvRouter {
indexer: KvIndexer, indexer: Option<KvIndexer>,
scheduler: KvScheduler, scheduler: KvScheduler,
block_size: u32, block_size: u32,
} }
...@@ -113,16 +108,19 @@ impl KvRouter { ...@@ -113,16 +108,19 @@ impl KvRouter {
component: Component, component: Component,
block_size: u32, block_size: u32,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>, selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
use_kv_events: bool,
) -> Result<Self> { ) -> Result<Self> {
let cancellation_token = component let cancellation_token = component
.drt() .drt()
.primary_lease() .primary_lease()
.expect("Cannot KV route static workers") .expect("Cannot KV route static workers")
.primary_token(); .primary_token();
tracing::info!("KV Routing initialized");
let metrics_aggregator = let metrics_aggregator =
KvMetricsAggregator::new(component.clone(), cancellation_token.clone()).await; EndpointCollector::new(component.clone(), cancellation_token.clone()).await;
let indexer = KvIndexer::new(cancellation_token.clone(), block_size);
let maybe_indexer =
use_kv_events.then(|| KvIndexer::new(cancellation_token.clone(), block_size));
let scheduler = KvScheduler::start( let scheduler = KvScheduler::start(
component.namespace().clone(), component.namespace().clone(),
block_size, block_size,
...@@ -133,50 +131,47 @@ impl KvRouter { ...@@ -133,50 +131,47 @@ impl KvRouter {
// [gluo TODO] try subscribe_with_type::<RouterEvent>, // [gluo TODO] try subscribe_with_type::<RouterEvent>,
// error checking below will be different. // error checking below will be different.
let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?; if let Some(ref indexer) = maybe_indexer {
let kv_events_tx = indexer.event_sender(); let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
let kv_events_tx = indexer.event_sender();
tokio::spawn(async move {
while let Some(event) = kv_events_rx.next().await { tokio::spawn(async move {
let event: RouterEvent = match serde_json::from_slice(&event.payload) { while let Some(event) = kv_events_rx.next().await {
Ok(event) => event, let event: RouterEvent = match serde_json::from_slice(&event.payload) {
Err(e) => { Ok(event) => event,
tracing::warn!("Failed to deserialize RouterEvent: {:?}", e); Err(e) => {
// Choosing warn and continue to process other events from other workers tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
// A bad event likely signals a problem with a worker, but potentially other workers are still healthy // Choosing warn and continue to process other events from other workers
continue; // A bad event likely signals a problem with a worker, but potentially other workers are still healthy
continue;
}
};
if let Err(e) = kv_events_tx.send(event).await {
tracing::debug!(
"failed to send kv event to indexer; shutting down: {:?}",
e
);
} }
};
if let Err(e) = kv_events_tx.send(event).await {
tracing::debug!("failed to send kv event to indexer; shutting down: {:?}", e);
} }
} });
}); }
tracing::info!("KV Routing initialized");
Ok(Self { Ok(Self {
indexer: maybe_indexer,
scheduler, scheduler,
indexer,
block_size, block_size,
}) })
} }
// [TODO] indexer needs to take 'lora_id' as parameter
pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<i64> {
// Extracting part of the code in KvRouter::generate() for only
// the decision making part, routing is done by the caller
let isl_tokens = token_ids.len();
let overlap_scores = self
.indexer
.find_matches_for_request(token_ids.as_slice())
.await?;
tracing::debug!("KV router overlap_scores: {:?}", overlap_scores);
let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
Ok(worker_id)
}
/// Give these tokens, find the worker with the best match in it's KV cache. /// Give these tokens, find the worker with the best match in it's KV cache.
/// Returned overlap amount is in number of blocks. /// Returned overlap amount is in number of blocks.
async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<(i64, u32)> { /// Now also takes context_id for request tracking
async fn find_best_match(
&self,
context_id: &str,
tokens: &[u32],
) -> anyhow::Result<(i64, u32)> {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let block_size = self.block_size; let block_size = self.block_size;
...@@ -187,13 +182,38 @@ impl KvRouter { ...@@ -187,13 +182,38 @@ impl KvRouter {
.into_iter() .into_iter()
.map(|block| LocalBlockHash(block.block_hash())) .map(|block| LocalBlockHash(block.block_hash()))
.collect(); .collect();
let overlap_scores = self.indexer.find_matches(local_block_hashes).await?; let overlap_scores = match &self.indexer {
let worker_id = self Some(indexer) => indexer.find_matches(local_block_hashes).await?,
None => Default::default(), // Returns empty/default instance
};
let best_worker_id = self
.scheduler .scheduler
.schedule(overlap_scores.clone(), isl_tokens) .schedule(
context_id.to_string(),
isl_tokens,
block_size,
tokens,
overlap_scores.clone(),
)
.await?; .await?;
let overlap_amount = overlap_scores.scores.get(&worker_id).copied().unwrap_or(0);
Ok((worker_id, overlap_amount)) let overlap_amount = overlap_scores
.scores
.get(&best_worker_id)
.copied()
.unwrap_or(0);
Ok((best_worker_id, overlap_amount))
}
/// Push a token to a specific request's sequence
pub async fn push(&self, request_id: &String, token: u32) {
self.scheduler.push(request_id, token).await
}
/// Free all blocks associated with a request
pub async fn free(&self, request_id: &String) {
self.scheduler.free(request_id).await
} }
/// Get the block size this router was configured with /// Get the block size this router was configured with
...@@ -202,6 +222,7 @@ impl KvRouter { ...@@ -202,6 +222,7 @@ impl KvRouter {
} }
} }
// NOTE: this would not be usable for now, should deprecate
#[async_trait] #[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter { impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
async fn generate( async fn generate(
...@@ -209,7 +230,7 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er ...@@ -209,7 +230,7 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
request: SingleIn<RouterRequest>, request: SingleIn<RouterRequest>,
) -> Result<ManyOut<Annotated<RouterResponse>>> { ) -> Result<ManyOut<Annotated<RouterResponse>>> {
let (request, ctx) = request.into_parts(); let (request, ctx) = request.into_parts();
let (worker_id, _) = self.find_best_match(&request.tokens).await?; let (worker_id, _) = self.find_best_match(ctx.id(), &request.tokens).await?;
let response = RouterResponse { worker_id }; let response = RouterResponse { worker_id };
let response = Annotated::from_data(response); let response = Annotated::from_data(response);
...@@ -243,13 +264,40 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -243,13 +264,40 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
match self.inner.client.instance_source.as_ref() { match self.inner.client.instance_source.as_ref() {
InstanceSource::Static => self.inner.r#static(request).await, InstanceSource::Static => self.inner.r#static(request).await,
InstanceSource::Dynamic(_) => { InstanceSource::Dynamic(_) => {
let (instance_id, overlap_amount) = // Extract context ID for request tracking
self.chooser.find_best_match(&request.token_ids).await?; let context_id = request.context().id().to_string();
let (instance_id, overlap_amount) = self
.chooser
.find_best_match(&context_id, &request.token_ids)
.await?;
// Update the request with the estimated prefix hit blocks // Update the request with the estimated prefix hit blocks
let (mut backend_input, context) = request.into_parts(); let (mut backend_input, context) = request.into_parts();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount); backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
let updated_request = context.map(|_| backend_input); let updated_request = context.map(|_| backend_input);
self.inner.direct(updated_request, instance_id).await
// Get the response stream from the worker
let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
// Wrap the stream to track tokens
let stream_context = response_stream.context();
let chooser = self.chooser.clone();
let request_id = context_id.clone();
let wrapped_stream = Box::pin(async_stream::stream! {
while let Some(item) = response_stream.next().await {
// Track tokens if they exist in the response
if let Some(ref output) = item.data {
for token_id in &output.token_ids {
chooser.push(&request_id, *token_id).await;
}
}
yield item;
}
chooser.free(&request_id).await;
});
Ok(ResponseStream::new(wrapped_stream, stream_context))
} }
} }
} }
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
use std::sync::Once; use std::sync::Once;
pub use crate::kv_router::protocols::ForwardPassMetrics; pub use crate::kv_router::protocols::{ForwardPassMetrics, LoadMetrics, PredictiveLoadMetrics};
use crate::kv_router::KV_METRICS_ENDPOINT; use crate::kv_router::KV_METRICS_ENDPOINT;
use crate::kv_router::scheduler::Endpoint; use crate::kv_router::scheduler::Endpoint;
...@@ -28,6 +28,37 @@ use tokio_util::sync::CancellationToken; ...@@ -28,6 +28,37 @@ use tokio_util::sync::CancellationToken;
static METRICS_WAITING_MESSAGE: Once = Once::new(); static METRICS_WAITING_MESSAGE: Once = Once::new();
static METRICS_FOUND_MESSAGE: Once = Once::new(); static METRICS_FOUND_MESSAGE: Once = Once::new();
pub struct EndpointCollector {
pub service_name: String,
pub endpoints_rx: watch::Receiver<ProcessedEndpoints>,
}
impl EndpointCollector {
pub async fn new(component: Component, cancellation_token: CancellationToken) -> Self {
let (watch_tx, watch_rx) = watch::channel(ProcessedEndpoints::default());
tokio::spawn(collect_endpoints_task(
component.clone(),
watch_tx,
cancellation_token.clone(),
"generate".to_string(),
));
Self {
service_name: component.service_name(),
endpoints_rx: watch_rx,
}
}
pub fn get_endpoints(&self) -> ProcessedEndpoints {
self.endpoints_rx.borrow().clone()
}
pub fn endpoints_watcher(&self) -> watch::Receiver<ProcessedEndpoints> {
self.endpoints_rx.clone()
}
}
pub struct KvMetricsAggregator { pub struct KvMetricsAggregator {
pub service_name: String, pub service_name: String,
pub endpoints_rx: watch::Receiver<ProcessedEndpoints>, pub endpoints_rx: watch::Receiver<ProcessedEndpoints>,
...@@ -41,6 +72,7 @@ impl KvMetricsAggregator { ...@@ -41,6 +72,7 @@ impl KvMetricsAggregator {
component.clone(), component.clone(),
watch_tx, watch_tx,
cancellation_token.clone(), cancellation_token.clone(),
KV_METRICS_ENDPOINT.to_string(),
)); ));
Self { Self {
...@@ -93,12 +125,16 @@ pub async fn collect_endpoints_task( ...@@ -93,12 +125,16 @@ pub async fn collect_endpoints_task(
component: Component, component: Component,
watch_tx: watch::Sender<ProcessedEndpoints>, watch_tx: watch::Sender<ProcessedEndpoints>,
cancel: CancellationToken, cancel: CancellationToken,
subject: String,
) { ) {
let backoff_delay = Duration::from_millis(100); let backoff_delay = Duration::from_millis(100);
let scrape_timeout = Duration::from_millis(300); let scrape_timeout = Duration::from_millis(300);
let endpoint = component.endpoint(KV_METRICS_ENDPOINT); let endpoint = component.endpoint(&subject);
let service_subject = endpoint.subject(); let service_subject = endpoint.subject();
// Keep track of the last sent value to avoid unnecessary updates
let mut last_sent: Option<ProcessedEndpoints> = None;
loop { loop {
tokio::select! { tokio::select! {
_ = cancel.cancelled() => { _ = cancel.cancelled() => {
...@@ -115,30 +151,58 @@ pub async fn collect_endpoints_task( ...@@ -115,30 +151,58 @@ pub async fn collect_endpoints_task(
continue; continue;
} }
}; };
let endpoints: Vec<Endpoint> = unfiltered_endpoints
.into_iter() let endpoints: Vec<Endpoint> = if subject == KV_METRICS_ENDPOINT {
.filter(|s| s.data.is_some()) // Original filtering behavior
.filter_map(|s| unfiltered_endpoints
match s.data.unwrap().decode::<ForwardPassMetrics>() { .into_iter()
Ok(data) => Some(Endpoint { .filter_map(|s| {
name: s.name, s.data?
subject: s.subject, .decode::<ForwardPassMetrics>()
data, .map(|data| Endpoint {
}), name: s.name,
Err(e) => { subject: s.subject,
tracing::debug!("skip endpoint data that can't be parsed as ForwardPassMetrics: {:?}", e); data: LoadMetrics::EngineLoadMetrics(data),
None })
} .inspect_err(|e| {
} tracing::warn!("skip endpoint data that can't be parsed as ForwardPassMetrics: {:?}", e);
) })
.collect(); .ok()
})
.collect()
} else {
// No filtering - just use default LoadMetrics
unfiltered_endpoints
.into_iter()
.map(|s| Endpoint {
name: s.name,
subject: s.subject,
data: LoadMetrics::default(),
})
.collect()
};
tracing::trace!("Found {} endpoints for service: {service_subject}", endpoints.len()); tracing::trace!("Found {} endpoints for service: {service_subject}", endpoints.len());
let processed = ProcessedEndpoints::new(endpoints); let processed = ProcessedEndpoints::new(endpoints);
if watch_tx.send(processed).is_err() { // Only send if different from last sent value
tracing::trace!("failed to send processed endpoints; shutting down"); // This is necessary because the watch channel does not track changes
break; // https://docs.rs/tokio/latest/tokio/sync/watch/struct.Receiver.html#method.has_changed
let should_send = match &last_sent {
Some(last) => last != &processed,
None => true,
};
if should_send {
tracing::trace!("Endpoints changed, sending update for service: {service_subject}");
if watch_tx.send(processed.clone()).is_err() {
tracing::error!("failed to send processed endpoints; shutting down");
break;
}
last_sent = Some(processed);
} else {
tracing::trace!("Endpoints unchanged, skipping update for service: {service_subject}");
} }
} }
} }
......
...@@ -39,14 +39,14 @@ pub struct WorkerSelectionResult { ...@@ -39,14 +39,14 @@ pub struct WorkerSelectionResult {
pub overlap_blocks: usize, pub overlap_blocks: usize,
} }
#[derive(Debug, Clone, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct ForwardPassMetrics { pub struct ForwardPassMetrics {
pub worker_stats: WorkerStats, pub worker_stats: WorkerStats,
pub kv_stats: KvStats, pub kv_stats: KvStats,
pub spec_decode_stats: Option<SpecDecodeStats>, pub spec_decode_stats: Option<SpecDecodeStats>,
} }
#[derive(Debug, Clone, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct WorkerStats { pub struct WorkerStats {
// https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models // https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models
pub data_parallel_rank: Option<u32>, pub data_parallel_rank: Option<u32>,
...@@ -55,7 +55,7 @@ pub struct WorkerStats { ...@@ -55,7 +55,7 @@ pub struct WorkerStats {
pub num_requests_waiting: u64, pub num_requests_waiting: u64,
} }
#[derive(Debug, Clone, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct KvStats { pub struct KvStats {
pub kv_active_blocks: u64, pub kv_active_blocks: u64,
pub kv_total_blocks: u64, pub kv_total_blocks: u64,
...@@ -65,7 +65,34 @@ pub struct KvStats { ...@@ -65,7 +65,34 @@ pub struct KvStats {
pub gpu_prefix_cache_hit_rate: f32, pub gpu_prefix_cache_hit_rate: f32,
} }
#[derive(Debug, Clone, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct PredictiveLoadMetrics {
pub kv_active_blocks: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum LoadMetrics {
EngineLoadMetrics(ForwardPassMetrics),
PredictiveLoadMetrics(PredictiveLoadMetrics),
}
impl LoadMetrics {
pub fn kv_active_blocks(&self) -> u64 {
match self {
LoadMetrics::EngineLoadMetrics(metrics) => metrics.kv_stats.kv_active_blocks,
LoadMetrics::PredictiveLoadMetrics(metrics) => metrics.kv_active_blocks,
}
}
}
impl Default for LoadMetrics {
fn default() -> Self {
LoadMetrics::PredictiveLoadMetrics(PredictiveLoadMetrics::default())
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct SpecDecodeStats { pub struct SpecDecodeStats {
pub num_spec_tokens: Option<u32>, pub num_spec_tokens: Option<u32>,
pub num_drafts: Option<u32>, pub num_drafts: Option<u32>,
......
This diff is collapsed.
...@@ -20,7 +20,7 @@ use std::collections::HashMap; ...@@ -20,7 +20,7 @@ use std::collections::HashMap;
use crate::kv_router::scheduler::Endpoint; use crate::kv_router::scheduler::Endpoint;
#[derive(Debug, Default, Serialize, Deserialize, Clone)] #[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)]
pub struct ProcessedEndpoints { pub struct ProcessedEndpoints {
pub endpoints: HashMap<i64, Endpoint>, pub endpoints: HashMap<i64, Endpoint>,
pub load_avg: f64, pub load_avg: f64,
...@@ -32,7 +32,7 @@ impl ProcessedEndpoints { ...@@ -32,7 +32,7 @@ impl ProcessedEndpoints {
// compute some basic statistics // compute some basic statistics
let load_values: Vec<f64> = endpoints let load_values: Vec<f64> = endpoints
.iter() .iter()
.map(|x| x.data.kv_stats.kv_active_blocks as f64) .map(|endpoint| endpoint.data.kv_active_blocks() as f64)
.collect(); .collect();
let load_avg = load_values.iter().copied().sum::<f64>() / load_values.len() as f64; let load_avg = load_values.iter().copied().sum::<f64>() / load_values.len() as f64;
let variance = load_values let variance = load_values
...@@ -50,4 +50,15 @@ impl ProcessedEndpoints { ...@@ -50,4 +50,15 @@ impl ProcessedEndpoints {
load_std, load_std,
} }
} }
pub fn worker_ids(&self) -> Vec<i64> {
self.endpoints.keys().copied().collect()
}
pub fn active_blocks(&self) -> HashMap<i64, usize> {
self.endpoints
.iter()
.map(|(&worker_id, endpoint)| (worker_id, endpoint.data.kv_active_blocks() as usize))
.collect()
}
} }
This diff is collapsed.
...@@ -46,8 +46,9 @@ ...@@ -46,8 +46,9 @@
//! implementation of the main block manager. //! implementation of the main block manager.
use crate::mocker::evictor::LRUEvictor; use crate::mocker::evictor::LRUEvictor;
use crate::mocker::protocols::{MoveBlock, MoveBlockResponse, PrefillCost, UniqueBlock}; use crate::mocker::protocols::{MoveBlock, MoveBlockResponse, PrefillCost};
use crate::mocker::sequence::ActiveSequence; use crate::mocker::sequence::ActiveSequence;
use crate::tokens::blocks::UniqueBlock;
use derive_getters::Getters; use derive_getters::Getters;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use tokio::sync::mpsc; use tokio::sync::mpsc;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment