Unverified Commit 84e71e27 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: predictive active blocks for routing without load metrics (#1731)


Signed-off-by: Yan Ru Pei <yanrpei@gmail.com>
Co-authored-by: Alec <35311602+alec-flowers@users.noreply.github.com>
parent ffccc722
...@@ -83,7 +83,7 @@ use serde::{Deserialize, Serialize}; ...@@ -83,7 +83,7 @@ use serde::{Deserialize, Serialize};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::time::Duration as StdDuration; use std::time::Duration as StdDuration;
use dynamo_llm::kv_router::protocols::ForwardPassMetrics; use dynamo_llm::kv_router::protocols::{ForwardPassMetrics, LoadMetrics};
use dynamo_llm::kv_router::scheduler::Endpoint; use dynamo_llm::kv_router::scheduler::Endpoint;
use dynamo_llm::kv_router::scoring::ProcessedEndpoints; use dynamo_llm::kv_router::scoring::ProcessedEndpoints;
...@@ -449,7 +449,10 @@ impl PrometheusMetrics { ...@@ -449,7 +449,10 @@ impl PrometheusMetrics {
// Update per-worker metrics // Update per-worker metrics
for (worker_id, endpoint) in processed.endpoints.iter() { for (worker_id, endpoint) in processed.endpoints.iter() {
let worker_id = worker_id.to_string(); let worker_id = worker_id.to_string();
let metrics = endpoint.data.clone(); let load_metrics = endpoint.data.clone();
let LoadMetrics::EngineLoadMetrics(metrics) = load_metrics else {
panic!("Can only update with ForwardPassMetrics");
};
self.set_worker_gauge( self.set_worker_gauge(
&self.kv_blocks_active, &self.kv_blocks_active,
...@@ -602,7 +605,7 @@ pub fn postprocess_metrics( ...@@ -602,7 +605,7 @@ pub fn postprocess_metrics(
e.id().ok().map(|id| Endpoint { e.id().ok().map(|id| Endpoint {
name: format!("worker-{id}"), name: format!("worker-{id}"),
subject: e.subject.clone(), subject: e.subject.clone(),
data: m.clone(), data: LoadMetrics::EngineLoadMetrics(m.clone()),
}) })
}) })
.collect(); .collect();
......
...@@ -66,7 +66,7 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -66,7 +66,7 @@ async fn app(runtime: Runtime) -> Result<()> {
let selector = Box::new(CustomWorkerSelector::default()); let selector = Box::new(CustomWorkerSelector::default());
let router = KvRouter::new(component.clone(), args.block_size, Some(selector)).await?; let router = KvRouter::new(component.clone(), args.block_size, Some(selector), true).await?;
let router = Ingress::for_engine(Arc::new(router))?; let router = Ingress::for_engine(Arc::new(router))?;
component component
......
...@@ -3392,14 +3392,8 @@ index cafd8150b..6a5e45b4e 100644 ...@@ -3392,14 +3392,8 @@ index cafd8150b..6a5e45b4e 100644
+ num_requests_waiting: int + num_requests_waiting: int
+ gpu_cache_usage_perc: float + gpu_cache_usage_perc: float
+ gpu_prefix_cache_hit_rate: float + gpu_prefix_cache_hit_rate: float
+ spec_decode_draft_acceptance_rate: Optional[float] = None
+ spec_decode_system_efficiency: Optional[float] = None
+ spec_decode_draft_tokens: Optional[int] = None
+ spec_decode_emitted_tokens: Optional[int] = None
+ spec_decode_accepted_tokens: Optional[int] = None
+ spec_decode_num_spec_tokens: Optional[int] = None
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index f058b1329..2fdb5b8bf 100644 index f058b1329..fd5610a3c 100644
--- a/vllm/engine/multiprocessing/client.py --- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py
@@ -1,4 +1,17 @@ @@ -1,4 +1,17 @@
...@@ -3460,16 +3454,25 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3460,16 +3454,25 @@ index f058b1329..2fdb5b8bf 100644
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
# yapf: enable # yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT from vllm.envs import VLLM_RPC_TIMEOUT
@@ -48,6 +66,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest @@ -48,6 +66,17 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import Device, deprecate_kwargs from vllm.utils import Device, deprecate_kwargs
+from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest, RemotePrefillRequestCallback +from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest, RemotePrefillRequestCallback
+from vllm.distributed.device_communicators.nixl import NixlMetadata +from vllm.distributed.device_communicators.nixl import NixlMetadata
+
+# Import ForwardPassMetrics and related classes from dynamo
+try:
+ from dynamo.llm import ForwardPassMetrics, WorkerStats, KvStats
+except ImportError:
+ # Fallback if dynamo imports are not available
+ ForwardPassMetrics = None
+ WorkerStats = None
+ KvStats = None
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -93,6 +113,7 @@ class MQLLMEngineClient(EngineClient): @@ -93,6 +122,7 @@ class MQLLMEngineClient(EngineClient):
self._errored_with: Optional[BaseException] = None self._errored_with: Optional[BaseException] = None
# Get the configs. # Get the configs.
...@@ -3477,7 +3480,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3477,7 +3480,7 @@ index f058b1329..2fdb5b8bf 100644
self.model_config = engine_config.model_config self.model_config = engine_config.model_config
self.decoding_config = engine_config.decoding_config self.decoding_config = engine_config.decoding_config
@@ -117,6 +138,10 @@ class MQLLMEngineClient(EngineClient): @@ -117,6 +147,10 @@ class MQLLMEngineClient(EngineClient):
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL) self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
...@@ -3488,7 +3491,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3488,7 +3491,7 @@ index f058b1329..2fdb5b8bf 100644
# IPC path for the data socket. # IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
@@ -131,8 +156,27 @@ class MQLLMEngineClient(EngineClient): @@ -131,8 +165,27 @@ class MQLLMEngineClient(EngineClient):
# Loop to check health of the LLMEngine periodically. # Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready. # Started after the MQLLMEngine is ready.
self.health_loop: Optional[asyncio.Task] = None self.health_loop: Optional[asyncio.Task] = None
...@@ -3516,7 +3519,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3516,7 +3519,7 @@ index f058b1329..2fdb5b8bf 100644
@staticmethod @staticmethod
def is_unsupported_config(vllm_config: VllmConfig): def is_unsupported_config(vllm_config: VllmConfig):
# Pipeline parallel not yet supported # Pipeline parallel not yet supported
@@ -182,6 +226,61 @@ class MQLLMEngineClient(EngineClient): @@ -182,6 +235,76 @@ class MQLLMEngineClient(EngineClient):
except Exception as e: except Exception as e:
self._set_errored(e) self._set_errored(e)
...@@ -3553,13 +3556,28 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3553,13 +3556,28 @@ index f058b1329..2fdb5b8bf 100644
+ if self.metrics_publisher is not None and isinstance( + if self.metrics_publisher is not None and isinstance(
+ metrics, KvMetrics + metrics, KvMetrics
+ ): + ):
+ self.metrics_publisher.publish(metrics.request_active_slots, + # Construct structured metrics objects
+ metrics.request_total_slots, + worker_stats = WorkerStats(
+ metrics.kv_active_blocks, + request_active_slots=metrics.request_active_slots,
+ metrics.kv_total_blocks, + request_total_slots=metrics.request_total_slots,
+ metrics.num_requests_waiting, + num_requests_waiting=metrics.num_requests_waiting,
+ metrics.gpu_cache_usage_perc, + data_parallel_rank=None
+ metrics.gpu_prefix_cache_hit_rate) + )
+
+ kv_stats = KvStats(
+ kv_active_blocks=metrics.kv_active_blocks,
+ kv_total_blocks=metrics.kv_total_blocks,
+ gpu_cache_usage_perc=metrics.gpu_cache_usage_perc,
+ gpu_prefix_cache_hit_rate=metrics.gpu_prefix_cache_hit_rate
+ )
+
+ forward_pass_metrics = ForwardPassMetrics(
+ worker_stats=worker_stats,
+ kv_stats=kv_stats,
+ spec_decode_stats=None
+ )
+
+ self.metrics_publisher.publish(forward_pass_metrics)
+ logger.debug("Metrics successful.") + logger.debug("Metrics successful.")
+ +
+ # TODO: Investigate sending whole stats object + # TODO: Investigate sending whole stats object
...@@ -3578,7 +3596,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3578,7 +3596,7 @@ index f058b1329..2fdb5b8bf 100644
async def run_output_handler_loop(self): async def run_output_handler_loop(self):
"""Get RequestOutputs from Engine and stream to Request Queues""" """Get RequestOutputs from Engine and stream to Request Queues"""
@@ -250,7 +349,7 @@ class MQLLMEngineClient(EngineClient): @@ -250,7 +373,7 @@ class MQLLMEngineClient(EngineClient):
# Put each output into the appropriate queue. # Put each output into the appropriate queue.
elif isinstance( elif isinstance(
request_outputs, request_outputs,
...@@ -3587,7 +3605,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3587,7 +3605,7 @@ index f058b1329..2fdb5b8bf 100644
self._add_output(request_outputs) self._add_output(request_outputs)
else: else:
for request_output in request_outputs: for request_output in request_outputs:
@@ -261,7 +360,7 @@ class MQLLMEngineClient(EngineClient): @@ -261,7 +384,7 @@ class MQLLMEngineClient(EngineClient):
def _add_output(self, request_output: Union[RequestOutput, def _add_output(self, request_output: Union[RequestOutput,
RPCAdapterLoadedResponse, RPCAdapterLoadedResponse,
...@@ -3596,7 +3614,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3596,7 +3614,7 @@ index f058b1329..2fdb5b8bf 100644
queue = self.output_queues.get(request_output.request_id) queue = self.output_queues.get(request_output.request_id)
if queue is not None: if queue is not None:
queue.put_nowait(request_output) queue.put_nowait(request_output)
@@ -283,12 +382,25 @@ class MQLLMEngineClient(EngineClient): @@ -283,12 +406,25 @@ class MQLLMEngineClient(EngineClient):
# Wait until server is ready. # Wait until server is ready.
response = await self._wait_for_server_rpc(socket) response = await self._wait_for_server_rpc(socket)
...@@ -3622,7 +3640,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3622,7 +3640,7 @@ index f058b1329..2fdb5b8bf 100644
def close(self): def close(self):
"""Destroy the ZeroMQ Context.""" """Destroy the ZeroMQ Context."""
@@ -298,6 +410,8 @@ class MQLLMEngineClient(EngineClient): @@ -298,6 +434,8 @@ class MQLLMEngineClient(EngineClient):
# Cancel background tasks. # Cancel background tasks.
if self.health_loop is not None: if self.health_loop is not None:
self.health_loop.cancel() self.health_loop.cancel()
...@@ -3631,7 +3649,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3631,7 +3649,7 @@ index f058b1329..2fdb5b8bf 100644
if self.output_loop is not None: if self.output_loop is not None:
self.output_loop.cancel() self.output_loop.cancel()
@@ -420,6 +534,9 @@ class MQLLMEngineClient(EngineClient): @@ -420,6 +558,9 @@ class MQLLMEngineClient(EngineClient):
""" """
if self._errored_with is not None: if self._errored_with is not None:
raise self._errored_with raise self._errored_with
...@@ -3641,7 +3659,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3641,7 +3659,7 @@ index f058b1329..2fdb5b8bf 100644
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
@@ -478,6 +595,7 @@ class MQLLMEngineClient(EngineClient): @@ -478,6 +619,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
...@@ -3649,7 +3667,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3649,7 +3667,7 @@ index f058b1329..2fdb5b8bf 100644
*, *,
inputs: Optional[PromptType] = None # DEPRECATED inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
@@ -507,7 +625,8 @@ class MQLLMEngineClient(EngineClient): @@ -507,7 +649,8 @@ class MQLLMEngineClient(EngineClient):
return self._process_request(prompt, sampling_params, request_id, return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers, lora_request, trace_headers,
...@@ -3659,7 +3677,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3659,7 +3677,7 @@ index f058b1329..2fdb5b8bf 100644
@overload @overload
def encode( def encode(
@@ -591,6 +710,7 @@ class MQLLMEngineClient(EngineClient): @@ -591,6 +734,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
...@@ -3667,7 +3685,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3667,7 +3685,7 @@ index f058b1329..2fdb5b8bf 100644
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
PoolingRequestOutput, None]]: PoolingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses.""" """Send an RPCGenerateRequest to the RPCServer and stream responses."""
@@ -636,6 +756,12 @@ class MQLLMEngineClient(EngineClient): @@ -636,6 +780,12 @@ class MQLLMEngineClient(EngineClient):
else: else:
lp_bytes = None lp_bytes = None
...@@ -3680,7 +3698,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3680,7 +3698,7 @@ index f058b1329..2fdb5b8bf 100644
request_bytes = pickle.dumps( request_bytes = pickle.dumps(
RPCProcessRequest( RPCProcessRequest(
prompt=prompt, prompt=prompt,
@@ -645,11 +771,11 @@ class MQLLMEngineClient(EngineClient): @@ -645,11 +795,11 @@ class MQLLMEngineClient(EngineClient):
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
priority=priority, priority=priority,
...@@ -3694,7 +3712,7 @@ index f058b1329..2fdb5b8bf 100644 ...@@ -3694,7 +3712,7 @@ index f058b1329..2fdb5b8bf 100644
await self.input_socket.send_multipart(parts, copy=False) await self.input_socket.send_multipart(parts, copy=False)
# 4) Stream the RequestOutputs from the output queue. Note # 4) Stream the RequestOutputs from the output queue. Note
@@ -740,3 +866,22 @@ class MQLLMEngineClient(EngineClient): @@ -740,3 +890,22 @@ class MQLLMEngineClient(EngineClient):
# Raise on error, otherwise happily return None # Raise on error, otherwise happily return None
if isinstance(request_output, BaseException): if isinstance(request_output, BaseException):
raise request_output raise request_output
......
...@@ -15,6 +15,8 @@ See the License for the specific language governing permissions and ...@@ -15,6 +15,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
--> -->
>[!NOTE]
>This information is temporary and will change soon.
# KV Cache Routing # KV Cache Routing
This documentation explains how Key-Value (KV) cache routing works in Dynamo, providing optimized inference for large language models by intelligently directing requests to workers with the most relevant cached data while simultaneously load balancing based on utilization metrics sent by the workers. This documentation explains how Key-Value (KV) cache routing works in Dynamo, providing optimized inference for large language models by intelligently directing requests to workers with the most relevant cached data while simultaneously load balancing based on utilization metrics sent by the workers.
......
...@@ -8,7 +8,7 @@ It supports these engines: mistralrs, llamacpp, sglang, vllm, and tensorrt-llm. ...@@ -8,7 +8,7 @@ It supports these engines: mistralrs, llamacpp, sglang, vllm, and tensorrt-llm.
Usage: Usage:
``` ```
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=2.0] [--kv-gpu-cache-usage-weight=1.0] [--kv-waiting-requests-weight=1.0] [--verbosity (-v|-vv)] dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=1.0] [--router-temperature=0.5] [--verbosity (-v|-vv)]
``` ```
Example: `dynamo run Qwen/Qwen3-0.6B` Example: `dynamo run Qwen/Qwen3-0.6B`
...@@ -201,6 +201,8 @@ The only difference from the distributed system above is `--router-mode kv`. The ...@@ -201,6 +201,8 @@ The only difference from the distributed system above is `--router-mode kv`. The
For performance testing, compare a typical workload with `--router-mode random|round-robin` to see if it can benefit from KV-aware routing. For performance testing, compare a typical workload with `--router-mode random|round-robin` to see if it can benefit from KV-aware routing.
The argument `--kv-overlap-score-weight` sets the weight given to overlaps with prefix caches, which directly contributes to the prefill cost, so a larger weight is expected to yield a better TTFT (at the expense of worse ITL). When it is set to 0, prefix caches are not considered at all (falling back to pure load balancing on the active blocks), in which case the backend engines are not required to emit any KV events. The argument `--router-temperature` sets the temperature used when randomly selecting a worker to route to via softmax sampling on the router cost logits; setting it to 0 recovers the deterministic behavior where the worker with the minimum logit is picked.
## Full usage details ## Full usage details
`dynamo run` executes `dynamo-run`. `dynamo-run` is also an example of what can be built in Rust with the `dynamo-llm` and `dynamo-runtime` crates. The following guide shows how to build from source with all the features. `dynamo run` executes `dynamo-run`. `dynamo-run` is also an example of what can be built in Rust with the `dynamo-llm` and `dynamo-runtime` crates. The following guide shows how to build from source with all the features.
......
...@@ -15,6 +15,9 @@ See the License for the specific language governing permissions and ...@@ -15,6 +15,9 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
--> -->
>[!NOTE]
>This information is temporary and will change soon.
# KV Router Performance Tuning # KV Router Performance Tuning
## Overview ## Overview
......
...@@ -110,20 +110,23 @@ pub struct Flags { ...@@ -110,20 +110,23 @@ pub struct Flags {
#[arg(long, default_value = "round-robin")] #[arg(long, default_value = "round-robin")]
pub router_mode: RouterMode, pub router_mode: RouterMode,
/// Maximum number of batched tokens for KV routing
/// Needed for informing the KV router
/// TODO: derive from vllm args
/// NOTE: this is not actually used for now
#[arg(long, default_value = "8192")]
pub max_num_batched_tokens: Option<u32>,
/// KV Router: Weight for overlap score in worker selection. /// KV Router: Weight for overlap score in worker selection.
/// Higher values prioritize KV cache reuse. Default: 1.0 /// Higher values prioritize KV cache reuse. Default: 1.0
#[arg(long)] #[arg(long)]
pub kv_overlap_score_weight: Option<f64>, pub kv_overlap_score_weight: Option<f64>,
/// KV Router: Weight for GPU cache usage in worker selection. /// KV Router: Temperature for worker sampling via softmax.
/// Higher values avoid workers with nearly full KV caches. Default: 1.0 /// Higher values promote more randomness, and 0 falls back to deterministic selection.
#[arg(long)] /// Default: 0.5
pub kv_gpu_cache_usage_weight: Option<f64>,
/// KV Router: Weight for waiting requests in worker selection.
/// Higher values avoid workers with queued requests. Default: 1.0
#[arg(long)] #[arg(long)]
pub kv_waiting_requests_weight: Option<f64>, pub router_temperature: Option<f64>,
/// Max model context length. Reduce this if you don't have enough VRAM for the full model /// Max model context length. Reduce this if you don't have enough VRAM for the full model
/// context length (e.g. Llama 4). /// context length (e.g. Llama 4).
...@@ -211,8 +214,8 @@ impl Flags { ...@@ -211,8 +214,8 @@ impl Flags {
self.router_mode.into(), self.router_mode.into(),
KvRouterConfig::new( KvRouterConfig::new(
self.kv_overlap_score_weight, self.kv_overlap_score_weight,
self.kv_gpu_cache_usage_weight, self.router_temperature,
self.kv_waiting_requests_weight, self.max_num_batched_tokens,
), ),
) )
} }
......
...@@ -26,7 +26,14 @@ from vllm.entrypoints.openai.api_server import ( ...@@ -26,7 +26,14 @@ from vllm.entrypoints.openai.api_server import (
) )
from vllm.inputs import TokensPrompt from vllm.inputs import TokensPrompt
from dynamo.llm import ModelType, WorkerMetricsPublisher, register_llm from dynamo.llm import (
ForwardPassMetrics,
KvStats,
ModelType,
WorkerMetricsPublisher,
WorkerStats,
register_llm,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -70,15 +77,29 @@ class RequestHandler: ...@@ -70,15 +77,29 @@ class RequestHandler:
self.engine_client.set_metrics_publisher(self.metrics_publisher) self.engine_client.set_metrics_publisher(self.metrics_publisher)
# Initially send dummy metrics to kick start, # Initially send dummy metrics to kick start,
# vLLM will not update stat until forward pass is triggered # vLLM will not update stat until forward pass is triggered
self.metrics_publisher.publish(
0, # request_active_slots # Create the structured metrics objects
1024, # request_total_slots worker_stats = WorkerStats(
0, # kv_active_blocks request_active_slots=0,
1024, # kv_total_blocks request_total_slots=1024,
0, # num_requests_waiting num_requests_waiting=0,
0.0, # gpu_cache_usage_perc data_parallel_rank=None,
0.0, # gpu_prefix_cache_hit_rate )
kv_stats = KvStats(
kv_active_blocks=0,
kv_total_blocks=1024,
gpu_cache_usage_perc=0.0,
gpu_prefix_cache_hit_rate=0.0,
) )
metrics = ForwardPassMetrics(
worker_stats=worker_stats, kv_stats=kv_stats, spec_decode_stats=None
)
# Publish the metrics as a single object
self.metrics_publisher.publish(metrics)
task = asyncio.create_task(self.create_metrics_publisher_endpoint()) task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback( task.add_done_callback(
lambda _: logging.debug("metrics publisher endpoint created") lambda _: logging.debug("metrics publisher endpoint created")
......
...@@ -72,7 +72,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -72,7 +72,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Client>()?; m.add_class::<Client>()?;
m.add_class::<EtcdClient>()?; m.add_class::<EtcdClient>()?;
m.add_class::<AsyncResponseStream>()?; m.add_class::<AsyncResponseStream>()?;
m.add_class::<llm::kv::KvRouter>()?;
m.add_class::<llm::disagg_router::DisaggregatedRouter>()?; m.add_class::<llm::disagg_router::DisaggregatedRouter>()?;
m.add_class::<llm::kv::WorkerMetricsPublisher>()?; m.add_class::<llm::kv::WorkerMetricsPublisher>()?;
m.add_class::<llm::model_card::ModelDeploymentCard>()?; m.add_class::<llm::model_card::ModelDeploymentCard>()?;
......
...@@ -29,51 +29,6 @@ use tracing; ...@@ -29,51 +29,6 @@ use tracing;
use llm_rs::kv_router::protocols::*; use llm_rs::kv_router::protocols::*;
use llm_rs::kv_router::publisher::{create_stored_blocks, KvEventSourceConfig}; use llm_rs::kv_router::publisher::{create_stored_blocks, KvEventSourceConfig};
#[pyclass]
pub(crate) struct KvRouter {
inner: Arc<llm_rs::kv_router::KvRouter>,
}
#[pymethods]
impl KvRouter {
#[new]
fn new(component: Component, kv_block_size: usize) -> PyResult<Self> {
if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
};
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let inner = llm_rs::kv_router::KvRouter::new(
component.inner.clone(),
kv_block_size as u32,
None,
)
.await
.map_err(to_pyerr)?;
Ok(Self {
inner: Arc::new(inner),
})
})
}
fn schedule<'p>(
&self,
py: Python<'p>,
token_ids: Vec<u32>,
lora_id: u64,
) -> PyResult<Bound<'p, PyAny>> {
let router = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let worker_id = router
.schedule(&token_ids, lora_id)
.await
.map_err(to_pyerr)?;
Ok(worker_id)
})
}
}
#[pyfunction] #[pyfunction]
pub fn compute_block_hash_for_seq_py(tokens: Vec<u32>, kv_block_size: usize) -> PyResult<Vec<u64>> { pub fn compute_block_hash_for_seq_py(tokens: Vec<u32>, kv_block_size: usize) -> PyResult<Vec<u64>> {
if kv_block_size == 0 { if kv_block_size == 0 {
...@@ -617,25 +572,34 @@ impl KvMetricsAggregator { ...@@ -617,25 +572,34 @@ impl KvMetricsAggregator {
fn get_metrics<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> { fn get_metrics<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
// TODO: update EndpointKvMetrics to match the new ForwardPassMetrics struct // TODO: update EndpointKvMetrics to match the new ForwardPassMetrics struct
let endpoints = self.inner.get_endpoints(); let endpoints = self.inner.get_endpoints();
let load_avg = endpoints.load_avg;
let load_std = endpoints.load_std;
let endpoint_kv_metrics = endpoints let endpoint_kv_metrics = endpoints
.endpoints .endpoints
.iter() .into_iter()
.map(|(worker_id, endpoint)| EndpointKvMetrics { .map(|(worker_id, endpoint)| {
worker_id: *worker_id, let metrics = endpoint.data;
request_active_slots: endpoint.data.worker_stats.request_active_slots, let LoadMetrics::EngineLoadMetrics(fwd_pass_metrics) = metrics else {
request_total_slots: endpoint.data.worker_stats.request_total_slots, panic!("Endpoints do not contain forward pass metrics.");
kv_active_blocks: endpoint.data.kv_stats.kv_active_blocks, };
kv_total_blocks: endpoint.data.kv_stats.kv_total_blocks, EndpointKvMetrics {
num_requests_waiting: endpoint.data.worker_stats.num_requests_waiting, worker_id,
gpu_cache_usage_perc: endpoint.data.kv_stats.gpu_cache_usage_perc, request_active_slots: fwd_pass_metrics.worker_stats.request_active_slots,
gpu_prefix_cache_hit_rate: endpoint.data.kv_stats.gpu_prefix_cache_hit_rate, request_total_slots: fwd_pass_metrics.worker_stats.request_total_slots,
kv_active_blocks: fwd_pass_metrics.kv_stats.kv_active_blocks,
kv_total_blocks: fwd_pass_metrics.kv_stats.kv_total_blocks,
num_requests_waiting: fwd_pass_metrics.worker_stats.num_requests_waiting,
gpu_cache_usage_perc: fwd_pass_metrics.kv_stats.gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate: fwd_pass_metrics.kv_stats.gpu_prefix_cache_hit_rate,
}
}) })
.collect(); .collect();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
Ok(AggregatedMetrics { Ok(AggregatedMetrics {
endpoints: endpoint_kv_metrics, endpoints: endpoint_kv_metrics,
load_avg: endpoints.load_avg, load_avg,
load_std: endpoints.load_std, load_std,
}) })
}) })
} }
......
...@@ -272,25 +272,6 @@ class Client: ...@@ -272,25 +272,6 @@ class Client:
""" """
... ...
class KvRouter:
"""
A router will determine which worker should handle a given request.
"""
...
def __init__(self, drt: DistributedRuntime, component: Component) -> None:
"""
Create a `KvRouter` object that is associated with the `component`
"""
def schedule(self, token_ids: List[int], lora_id: int) -> int:
"""
Return the worker id that should handle the given token ids,
exception will be raised if there is no worker available.
"""
...
class DisaggregatedRouter: class DisaggregatedRouter:
""" """
A router that determines whether to perform prefill locally or remotely based on A router that determines whether to perform prefill locally or remotely based on
......
...@@ -32,7 +32,6 @@ from dynamo._core import KvEventPublisher as KvEventPublisher ...@@ -32,7 +32,6 @@ from dynamo._core import KvEventPublisher as KvEventPublisher
from dynamo._core import KvIndexer as KvIndexer from dynamo._core import KvIndexer as KvIndexer
from dynamo._core import KvMetricsAggregator as KvMetricsAggregator from dynamo._core import KvMetricsAggregator as KvMetricsAggregator
from dynamo._core import KvRecorder as KvRecorder from dynamo._core import KvRecorder as KvRecorder
from dynamo._core import KvRouter as KvRouter
from dynamo._core import KvStats as KvStats from dynamo._core import KvStats as KvStats
from dynamo._core import ModelType as ModelType from dynamo._core import ModelType as ModelType
from dynamo._core import OverlapScores as OverlapScores from dynamo._core import OverlapScores as OverlapScores
......
...@@ -212,8 +212,20 @@ impl ModelManager { ...@@ -212,8 +212,20 @@ impl ModelManager {
kv_cache_block_size: u32, kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
) -> anyhow::Result<Arc<KvRouter>> { ) -> anyhow::Result<Arc<KvRouter>> {
// Determine if we should use KV events based on overlap score weight
let use_kv_events = kv_router_config
.as_ref()
.map(|config| config.overlap_score_weight > 0.0)
.unwrap_or(false);
let selector = Box::new(DefaultWorkerSelector::new(kv_router_config)); let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
let chooser = KvRouter::new(component.clone(), kv_cache_block_size, Some(selector)).await?; let chooser = KvRouter::new(
component.clone(),
kv_cache_block_size,
Some(selector),
use_kv_events,
)
.await?;
let new_kv_chooser = Arc::new(chooser); let new_kv_chooser = Arc::new(chooser);
self.kv_choosers self.kv_choosers
.lock() .lock()
......
...@@ -23,11 +23,12 @@ pub mod publisher; ...@@ -23,11 +23,12 @@ pub mod publisher;
pub mod recorder; pub mod recorder;
pub mod scheduler; pub mod scheduler;
pub mod scoring; pub mod scoring;
pub mod sequence;
use crate::{ use crate::{
kv_router::{ kv_router::{
indexer::{KvIndexer, KvIndexerInterface, RouterEvent}, indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
metrics_aggregator::KvMetricsAggregator, metrics_aggregator::EndpointCollector,
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult}, protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest}, scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
scoring::ProcessedEndpoints, scoring::ProcessedEndpoints,
...@@ -58,25 +59,20 @@ pub trait WorkerSelector { ...@@ -58,25 +59,20 @@ pub trait WorkerSelector {
/// KV Router configuration parameters /// KV Router configuration parameters
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct KvRouterConfig { pub struct KvRouterConfig {
/// Weight for overlap score in worker selection.
/// Higher values prioritize KV cache reuse. Default: 2.0
pub overlap_score_weight: f64, pub overlap_score_weight: f64,
/// Weight for GPU cache usage in worker selection. pub router_temperature: f64,
/// Higher values avoid workers with nearly full KV caches. Default: 1.0
pub gpu_cache_usage_weight: f64,
/// Weight for waiting requests in worker selection. // note: this is not actually used for now
/// Higher values avoid workers with queued requests. Default: 1.0 pub max_num_batched_tokens: u32,
pub waiting_requests_weight: f64,
} }
impl Default for KvRouterConfig { impl Default for KvRouterConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
overlap_score_weight: 1.0, overlap_score_weight: 1.0,
gpu_cache_usage_weight: 1.0, router_temperature: 0.5,
waiting_requests_weight: 1.0, max_num_batched_tokens: 8192,
} }
} }
} }
...@@ -86,16 +82,15 @@ impl KvRouterConfig { ...@@ -86,16 +82,15 @@ impl KvRouterConfig {
/// If a weight is None, the default value will be used. /// If a weight is None, the default value will be used.
pub fn new( pub fn new(
overlap_score_weight: Option<f64>, overlap_score_weight: Option<f64>,
gpu_cache_usage_weight: Option<f64>, temperature: Option<f64>,
waiting_requests_weight: Option<f64>, max_num_batched_tokens: Option<u32>,
) -> Self { ) -> Self {
let default = Self::default(); let default = Self::default();
Self { Self {
overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight), overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
gpu_cache_usage_weight: gpu_cache_usage_weight router_temperature: temperature.unwrap_or(default.router_temperature),
.unwrap_or(default.gpu_cache_usage_weight), max_num_batched_tokens: max_num_batched_tokens
waiting_requests_weight: waiting_requests_weight .unwrap_or(default.max_num_batched_tokens),
.unwrap_or(default.waiting_requests_weight),
} }
} }
} }
...@@ -103,7 +98,7 @@ impl KvRouterConfig { ...@@ -103,7 +98,7 @@ impl KvRouterConfig {
/// A KvRouter only decides which worker you should use. It doesn't send you there. /// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route. /// TODO: Rename this to indicate it only selects a worker, it does not route.
pub struct KvRouter { pub struct KvRouter {
indexer: KvIndexer, indexer: Option<KvIndexer>,
scheduler: KvScheduler, scheduler: KvScheduler,
block_size: u32, block_size: u32,
} }
...@@ -113,16 +108,19 @@ impl KvRouter { ...@@ -113,16 +108,19 @@ impl KvRouter {
component: Component, component: Component,
block_size: u32, block_size: u32,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>, selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
use_kv_events: bool,
) -> Result<Self> { ) -> Result<Self> {
let cancellation_token = component let cancellation_token = component
.drt() .drt()
.primary_lease() .primary_lease()
.expect("Cannot KV route static workers") .expect("Cannot KV route static workers")
.primary_token(); .primary_token();
tracing::info!("KV Routing initialized");
let metrics_aggregator = let metrics_aggregator =
KvMetricsAggregator::new(component.clone(), cancellation_token.clone()).await; EndpointCollector::new(component.clone(), cancellation_token.clone()).await;
let indexer = KvIndexer::new(cancellation_token.clone(), block_size);
let maybe_indexer =
use_kv_events.then(|| KvIndexer::new(cancellation_token.clone(), block_size));
let scheduler = KvScheduler::start( let scheduler = KvScheduler::start(
component.namespace().clone(), component.namespace().clone(),
block_size, block_size,
...@@ -133,50 +131,47 @@ impl KvRouter { ...@@ -133,50 +131,47 @@ impl KvRouter {
// [gluo TODO] try subscribe_with_type::<RouterEvent>, // [gluo TODO] try subscribe_with_type::<RouterEvent>,
// error checking below will be different. // error checking below will be different.
let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?; if let Some(ref indexer) = maybe_indexer {
let kv_events_tx = indexer.event_sender(); let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
let kv_events_tx = indexer.event_sender();
tokio::spawn(async move {
while let Some(event) = kv_events_rx.next().await { tokio::spawn(async move {
let event: RouterEvent = match serde_json::from_slice(&event.payload) { while let Some(event) = kv_events_rx.next().await {
Ok(event) => event, let event: RouterEvent = match serde_json::from_slice(&event.payload) {
Err(e) => { Ok(event) => event,
tracing::warn!("Failed to deserialize RouterEvent: {:?}", e); Err(e) => {
// Choosing warn and continue to process other events from other workers tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
// A bad event likely signals a problem with a worker, but potentially other workers are still healthy // Choosing warn and continue to process other events from other workers
continue; // A bad event likely signals a problem with a worker, but potentially other workers are still healthy
continue;
}
};
if let Err(e) = kv_events_tx.send(event).await {
tracing::debug!(
"failed to send kv event to indexer; shutting down: {:?}",
e
);
} }
};
if let Err(e) = kv_events_tx.send(event).await {
tracing::debug!("failed to send kv event to indexer; shutting down: {:?}", e);
} }
} });
}); }
tracing::info!("KV Routing initialized");
Ok(Self { Ok(Self {
indexer: maybe_indexer,
scheduler, scheduler,
indexer,
block_size, block_size,
}) })
} }
// [TODO] indexer needs to take 'lora_id' as parameter
pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<i64> {
// Extracting part of the code in KvRouter::generate() for only
// the decision making part, routing is done by the caller
let isl_tokens = token_ids.len();
let overlap_scores = self
.indexer
.find_matches_for_request(token_ids.as_slice())
.await?;
tracing::debug!("KV router overlap_scores: {:?}", overlap_scores);
let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
Ok(worker_id)
}
/// Given these tokens, find the worker with the best match in its KV cache. /// Given these tokens, find the worker with the best match in its KV cache.
/// Returned overlap amount is in number of blocks. /// Returned overlap amount is in number of blocks.
async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<(i64, u32)> { /// Now also takes context_id for request tracking
async fn find_best_match(
&self,
context_id: &str,
tokens: &[u32],
) -> anyhow::Result<(i64, u32)> {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let block_size = self.block_size; let block_size = self.block_size;
...@@ -187,13 +182,38 @@ impl KvRouter { ...@@ -187,13 +182,38 @@ impl KvRouter {
.into_iter() .into_iter()
.map(|block| LocalBlockHash(block.block_hash())) .map(|block| LocalBlockHash(block.block_hash()))
.collect(); .collect();
let overlap_scores = self.indexer.find_matches(local_block_hashes).await?; let overlap_scores = match &self.indexer {
let worker_id = self Some(indexer) => indexer.find_matches(local_block_hashes).await?,
None => Default::default(), // Returns empty/default instance
};
let best_worker_id = self
.scheduler .scheduler
.schedule(overlap_scores.clone(), isl_tokens) .schedule(
context_id.to_string(),
isl_tokens,
block_size,
tokens,
overlap_scores.clone(),
)
.await?; .await?;
let overlap_amount = overlap_scores.scores.get(&worker_id).copied().unwrap_or(0);
Ok((worker_id, overlap_amount)) let overlap_amount = overlap_scores
.scores
.get(&best_worker_id)
.copied()
.unwrap_or(0);
Ok((best_worker_id, overlap_amount))
}
/// Push a token to a specific request's sequence.
///
/// Thin wrapper: forwards `request_id` and `token` unchanged to the scheduler's
/// `push` and awaits it. Returns nothing.
pub async fn push(&self, request_id: &String, token: u32) {
self.scheduler.push(request_id, token).await
}
/// Free all blocks associated with a request
pub async fn free(&self, request_id: &String) {
self.scheduler.free(request_id).await
} }
/// Get the block size this router was configured with /// Get the block size this router was configured with
...@@ -202,6 +222,7 @@ impl KvRouter { ...@@ -202,6 +222,7 @@ impl KvRouter {
} }
} }
// NOTE: this would not be usable for now, should deprecate
#[async_trait] #[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter { impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
async fn generate( async fn generate(
...@@ -209,7 +230,7 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er ...@@ -209,7 +230,7 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
request: SingleIn<RouterRequest>, request: SingleIn<RouterRequest>,
) -> Result<ManyOut<Annotated<RouterResponse>>> { ) -> Result<ManyOut<Annotated<RouterResponse>>> {
let (request, ctx) = request.into_parts(); let (request, ctx) = request.into_parts();
let (worker_id, _) = self.find_best_match(&request.tokens).await?; let (worker_id, _) = self.find_best_match(ctx.id(), &request.tokens).await?;
let response = RouterResponse { worker_id }; let response = RouterResponse { worker_id };
let response = Annotated::from_data(response); let response = Annotated::from_data(response);
...@@ -243,13 +264,40 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -243,13 +264,40 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
match self.inner.client.instance_source.as_ref() { match self.inner.client.instance_source.as_ref() {
InstanceSource::Static => self.inner.r#static(request).await, InstanceSource::Static => self.inner.r#static(request).await,
InstanceSource::Dynamic(_) => { InstanceSource::Dynamic(_) => {
let (instance_id, overlap_amount) = // Extract context ID for request tracking
self.chooser.find_best_match(&request.token_ids).await?; let context_id = request.context().id().to_string();
let (instance_id, overlap_amount) = self
.chooser
.find_best_match(&context_id, &request.token_ids)
.await?;
// Update the request with the estimated prefix hit blocks // Update the request with the estimated prefix hit blocks
let (mut backend_input, context) = request.into_parts(); let (mut backend_input, context) = request.into_parts();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount); backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
let updated_request = context.map(|_| backend_input); let updated_request = context.map(|_| backend_input);
self.inner.direct(updated_request, instance_id).await
// Get the response stream from the worker
let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
// Wrap the stream to track tokens
let stream_context = response_stream.context();
let chooser = self.chooser.clone();
let request_id = context_id.clone();
let wrapped_stream = Box::pin(async_stream::stream! {
while let Some(item) = response_stream.next().await {
// Track tokens if they exist in the response
if let Some(ref output) = item.data {
for token_id in &output.token_ids {
chooser.push(&request_id, *token_id).await;
}
}
yield item;
}
chooser.free(&request_id).await;
});
Ok(ResponseStream::new(wrapped_stream, stream_context))
} }
} }
} }
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
use std::sync::Once; use std::sync::Once;
pub use crate::kv_router::protocols::ForwardPassMetrics; pub use crate::kv_router::protocols::{ForwardPassMetrics, LoadMetrics, PredictiveLoadMetrics};
use crate::kv_router::KV_METRICS_ENDPOINT; use crate::kv_router::KV_METRICS_ENDPOINT;
use crate::kv_router::scheduler::Endpoint; use crate::kv_router::scheduler::Endpoint;
...@@ -28,6 +28,37 @@ use tokio_util::sync::CancellationToken; ...@@ -28,6 +28,37 @@ use tokio_util::sync::CancellationToken;
static METRICS_WAITING_MESSAGE: Once = Once::new(); static METRICS_WAITING_MESSAGE: Once = Once::new();
static METRICS_FOUND_MESSAGE: Once = Once::new(); static METRICS_FOUND_MESSAGE: Once = Once::new();
/// Holds the receiving side of a watch channel carrying the endpoints discovered
/// for a component, plus the component's service name.
///
/// Built by `EndpointCollector::new`, which spawns `collect_endpoints_task` to feed
/// the channel.
pub struct EndpointCollector {
/// Value returned by `component.service_name()` when the collector was created.
pub service_name: String,
/// Receiver for the `ProcessedEndpoints` snapshots published by the background task.
/// Starts out holding `ProcessedEndpoints::default()`.
pub endpoints_rx: watch::Receiver<ProcessedEndpoints>,
}
impl EndpointCollector {
    /// Creates a collector and starts the background task that feeds it.
    ///
    /// A watch channel is seeded with `ProcessedEndpoints::default()`, and
    /// `collect_endpoints_task` is spawned against the `"generate"` endpoint of
    /// `component`, with `cancellation_token` handed to it for shutdown.
    ///
    /// NOTE(review): for any subject other than `KV_METRICS_ENDPOINT` the task fills
    /// each endpoint with `LoadMetrics::default()`; this presumes `"generate"` is not
    /// that constant's value, so confirm.
    pub async fn new(component: Component, cancellation_token: CancellationToken) -> Self {
        let (endpoints_tx, endpoints_rx) = watch::channel(ProcessedEndpoints::default());

        // Build the task future first, then hand it to the runtime.
        let collector_task = collect_endpoints_task(
            component.clone(),
            endpoints_tx,
            cancellation_token.clone(),
            "generate".to_string(),
        );
        tokio::spawn(collector_task);

        let service_name = component.service_name();
        Self {
            service_name,
            endpoints_rx,
        }
    }

    /// Returns an owned copy of the most recently published endpoints.
    pub fn get_endpoints(&self) -> ProcessedEndpoints {
        ProcessedEndpoints::clone(&self.endpoints_rx.borrow())
    }

    /// Returns a new receiver cloned from this collector's watch receiver.
    pub fn endpoints_watcher(&self) -> watch::Receiver<ProcessedEndpoints> {
        watch::Receiver::clone(&self.endpoints_rx)
    }
}
pub struct KvMetricsAggregator { pub struct KvMetricsAggregator {
pub service_name: String, pub service_name: String,
pub endpoints_rx: watch::Receiver<ProcessedEndpoints>, pub endpoints_rx: watch::Receiver<ProcessedEndpoints>,
...@@ -41,6 +72,7 @@ impl KvMetricsAggregator { ...@@ -41,6 +72,7 @@ impl KvMetricsAggregator {
component.clone(), component.clone(),
watch_tx, watch_tx,
cancellation_token.clone(), cancellation_token.clone(),
KV_METRICS_ENDPOINT.to_string(),
)); ));
Self { Self {
...@@ -93,12 +125,16 @@ pub async fn collect_endpoints_task( ...@@ -93,12 +125,16 @@ pub async fn collect_endpoints_task(
component: Component, component: Component,
watch_tx: watch::Sender<ProcessedEndpoints>, watch_tx: watch::Sender<ProcessedEndpoints>,
cancel: CancellationToken, cancel: CancellationToken,
subject: String,
) { ) {
let backoff_delay = Duration::from_millis(100); let backoff_delay = Duration::from_millis(100);
let scrape_timeout = Duration::from_millis(300); let scrape_timeout = Duration::from_millis(300);
let endpoint = component.endpoint(KV_METRICS_ENDPOINT); let endpoint = component.endpoint(&subject);
let service_subject = endpoint.subject(); let service_subject = endpoint.subject();
// Keep track of the last sent value to avoid unnecessary updates
let mut last_sent: Option<ProcessedEndpoints> = None;
loop { loop {
tokio::select! { tokio::select! {
_ = cancel.cancelled() => { _ = cancel.cancelled() => {
...@@ -115,30 +151,58 @@ pub async fn collect_endpoints_task( ...@@ -115,30 +151,58 @@ pub async fn collect_endpoints_task(
continue; continue;
} }
}; };
let endpoints: Vec<Endpoint> = unfiltered_endpoints
.into_iter() let endpoints: Vec<Endpoint> = if subject == KV_METRICS_ENDPOINT {
.filter(|s| s.data.is_some()) // Original filtering behavior
.filter_map(|s| unfiltered_endpoints
match s.data.unwrap().decode::<ForwardPassMetrics>() { .into_iter()
Ok(data) => Some(Endpoint { .filter_map(|s| {
name: s.name, s.data?
subject: s.subject, .decode::<ForwardPassMetrics>()
data, .map(|data| Endpoint {
}), name: s.name,
Err(e) => { subject: s.subject,
tracing::debug!("skip endpoint data that can't be parsed as ForwardPassMetrics: {:?}", e); data: LoadMetrics::EngineLoadMetrics(data),
None })
} .inspect_err(|e| {
} tracing::warn!("skip endpoint data that can't be parsed as ForwardPassMetrics: {:?}", e);
) })
.collect(); .ok()
})
.collect()
} else {
// No filtering - just use default LoadMetrics
unfiltered_endpoints
.into_iter()
.map(|s| Endpoint {
name: s.name,
subject: s.subject,
data: LoadMetrics::default(),
})
.collect()
};
tracing::trace!("Found {} endpoints for service: {service_subject}", endpoints.len()); tracing::trace!("Found {} endpoints for service: {service_subject}", endpoints.len());
let processed = ProcessedEndpoints::new(endpoints); let processed = ProcessedEndpoints::new(endpoints);
if watch_tx.send(processed).is_err() { // Only send if different from last sent value
tracing::trace!("failed to send processed endpoints; shutting down"); // This is necessary because the watch channel does not track changes
break; // https://docs.rs/tokio/latest/tokio/sync/watch/struct.Receiver.html#method.has_changed
let should_send = match &last_sent {
Some(last) => last != &processed,
None => true,
};
if should_send {
tracing::trace!("Endpoints changed, sending update for service: {service_subject}");
if watch_tx.send(processed.clone()).is_err() {
tracing::error!("failed to send processed endpoints; shutting down");
break;
}
last_sent = Some(processed);
} else {
tracing::trace!("Endpoints unchanged, skipping update for service: {service_subject}");
} }
} }
} }
......
...@@ -39,14 +39,14 @@ pub struct WorkerSelectionResult { ...@@ -39,14 +39,14 @@ pub struct WorkerSelectionResult {
pub overlap_blocks: usize, pub overlap_blocks: usize,
} }
#[derive(Debug, Clone, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct ForwardPassMetrics { pub struct ForwardPassMetrics {
pub worker_stats: WorkerStats, pub worker_stats: WorkerStats,
pub kv_stats: KvStats, pub kv_stats: KvStats,
pub spec_decode_stats: Option<SpecDecodeStats>, pub spec_decode_stats: Option<SpecDecodeStats>,
} }
#[derive(Debug, Clone, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct WorkerStats { pub struct WorkerStats {
// https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models // https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models
pub data_parallel_rank: Option<u32>, pub data_parallel_rank: Option<u32>,
...@@ -55,7 +55,7 @@ pub struct WorkerStats { ...@@ -55,7 +55,7 @@ pub struct WorkerStats {
pub num_requests_waiting: u64, pub num_requests_waiting: u64,
} }
#[derive(Debug, Clone, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct KvStats { pub struct KvStats {
pub kv_active_blocks: u64, pub kv_active_blocks: u64,
pub kv_total_blocks: u64, pub kv_total_blocks: u64,
...@@ -65,7 +65,34 @@ pub struct KvStats { ...@@ -65,7 +65,34 @@ pub struct KvStats {
pub gpu_prefix_cache_hit_rate: f32, pub gpu_prefix_cache_hit_rate: f32,
} }
#[derive(Debug, Clone, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
/// Load description that carries only an active KV block count.
///
/// Payload of `LoadMetrics::PredictiveLoadMetrics`.
/// NOTE(review): presumably a router-side estimate rather than a value reported by
/// the engine — confirm against where it is populated.
pub struct PredictiveLoadMetrics {
/// Number of active KV blocks.
pub kv_active_blocks: u64,
}
/// Load information attached to an endpoint, in one of two forms.
///
/// The serde `rename_all = "snake_case"` attribute means the variant names are
/// written in snake_case when (de)serialized.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum LoadMetrics {
/// Full `ForwardPassMetrics`, as decoded from a worker's metrics endpoint.
EngineLoadMetrics(ForwardPassMetrics),
/// Active-block count only; this is the variant `LoadMetrics::default()` yields.
PredictiveLoadMetrics(PredictiveLoadMetrics),
}
impl LoadMetrics {
pub fn kv_active_blocks(&self) -> u64 {
match self {
LoadMetrics::EngineLoadMetrics(metrics) => metrics.kv_stats.kv_active_blocks,
LoadMetrics::PredictiveLoadMetrics(metrics) => metrics.kv_active_blocks,
}
}
}
impl Default for LoadMetrics {
fn default() -> Self {
LoadMetrics::PredictiveLoadMetrics(PredictiveLoadMetrics::default())
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct SpecDecodeStats { pub struct SpecDecodeStats {
pub num_spec_tokens: Option<u32>, pub num_spec_tokens: Option<u32>,
pub num_drafts: Option<u32>, pub num_drafts: Option<u32>,
......
...@@ -17,16 +17,21 @@ use dynamo_runtime::component::Namespace; ...@@ -17,16 +17,21 @@ use dynamo_runtime::component::Namespace;
use dynamo_runtime::traits::events::EventPublisher; use dynamo_runtime::traits::events::EventPublisher;
use rand::Rng; use rand::Rng;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use super::protocols::WorkerSelectionResult; use super::protocols::WorkerSelectionResult;
use super::WorkerSelector; use super::WorkerSelector;
use crate::kv_router::indexer::OverlapScores; use crate::kv_router::indexer::OverlapScores;
use crate::kv_router::protocols::ForwardPassMetrics; use crate::kv_router::indexer::WorkerId;
use crate::kv_router::protocols::LoadMetrics;
use crate::kv_router::scoring::ProcessedEndpoints; use crate::kv_router::scoring::ProcessedEndpoints;
use crate::kv_router::sequence::ActiveSequencesMultiWorker;
use crate::kv_router::KvRouterConfig; use crate::kv_router::KvRouterConfig;
use crate::kv_router::KV_HIT_RATE_SUBJECT; use crate::kv_router::KV_HIT_RATE_SUBJECT;
use crate::tokens::TokenBlockSequence;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVHitRateEvent { pub struct KVHitRateEvent {
...@@ -49,11 +54,11 @@ pub enum KvSchedulerError { ...@@ -49,11 +54,11 @@ pub enum KvSchedulerError {
/// [gluo FIXME] exactly the same as EndpointInfo except that 'data' /// [gluo FIXME] exactly the same as EndpointInfo except that 'data'
/// is cleaned (not optional) /// is cleaned (not optional)
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Endpoint { pub struct Endpoint {
pub name: String, pub name: String,
pub subject: String, pub subject: String,
pub data: ForwardPassMetrics, pub data: LoadMetrics,
} }
impl Endpoint { impl Endpoint {
...@@ -71,22 +76,30 @@ impl Endpoint { ...@@ -71,22 +76,30 @@ impl Endpoint {
} }
} }
/// Reply sent over a `SchedulingRequest`'s oneshot channel once a worker has been
/// selected for it.
#[derive(Debug)]
pub struct SchedulingResponse {
/// ID of the worker picked by the selector for this request.
pub best_worker_id: i64,
/// Worker IDs from an endpoint change that has not yet been handed to any caller,
/// or `None` if there is no pending change. `KvScheduler::schedule` passes these to
/// `update_workers` on its sequence tracker before recording the new request.
pub endpoints_changed: Option<Vec<i64>>,
}
pub struct SchedulingRequest { pub struct SchedulingRequest {
pub isl_tokens: usize, pub isl_tokens: usize,
pub overlap: OverlapScores, pub overlap: OverlapScores,
resp_tx: tokio::sync::oneshot::Sender<i64>, pub potential_blocks: HashMap<i64, usize>,
resp_tx: tokio::sync::oneshot::Sender<SchedulingResponse>,
} }
impl SchedulingRequest { impl SchedulingRequest {
pub fn respond(self, worker_id: i64) { pub fn respond(self, response: SchedulingResponse) {
if self.resp_tx.send(worker_id).is_err() { if self.resp_tx.send(response).is_err() {
tracing::trace!("failed to send response to requestor"); tracing::error!("failed to send response to requestor");
} }
} }
} }
pub struct KvScheduler { pub struct KvScheduler {
request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>, request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
sequences: Arc<Mutex<ActiveSequencesMultiWorker>>,
} }
impl KvScheduler { impl KvScheduler {
...@@ -110,12 +123,18 @@ impl KvScheduler { ...@@ -110,12 +123,18 @@ impl KvScheduler {
} }
}); });
let sequences = Arc::new(Mutex::new(ActiveSequencesMultiWorker::new(
block_size as usize,
endpoints.worker_ids(),
)));
// Channel to accept new scheduling requests // Channel to accept new scheduling requests
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024); let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
// Background task to handle scheduling requests // Background task to handle scheduling requests
tokio::spawn(async move { tokio::spawn(async move {
let mut request: SchedulingRequest; let mut request: SchedulingRequest;
let mut request_rx = request_rx; let mut request_rx = request_rx;
let mut pending_endpoint_update: Option<Vec<i64>> = None;
tracing::trace!("scheduler background task started"); tracing::trace!("scheduler background task started");
'outer: loop { 'outer: loop {
...@@ -137,30 +156,33 @@ impl KvScheduler { ...@@ -137,30 +156,33 @@ impl KvScheduler {
_ = endpoints_rx.changed() => { _ = endpoints_rx.changed() => {
endpoints = endpoints_rx.borrow_and_update().clone(); endpoints = endpoints_rx.borrow_and_update().clone();
pending_endpoint_update = Some(endpoints.worker_ids());
continue 'outer; continue 'outer;
} }
}; };
loop { loop {
match selector.select_worker(&endpoints, &request, block_size) { match selector.select_worker(&endpoints, &request, block_size) {
Ok(selection) => { Ok(selection) => {
let worker_id = process_worker_selection( if let Err(e) = event_tx.send(KVHitRateEvent {
endpoints.borrow_mut(), worker_id: selection.worker_id,
selection, isl_blocks: selection.required_blocks as usize,
&event_tx, overlap_blocks: selection.overlap_blocks,
); }) {
request.respond(worker_id); tracing::warn!("Failed to send KV hit rate event: {:?}", e);
}
let response = SchedulingResponse {
best_worker_id: selection.worker_id,
endpoints_changed: pending_endpoint_update.take(),
};
request.respond(response);
continue 'outer; continue 'outer;
} }
Err(KvSchedulerError::AllWorkersBusy) => { Err(KvSchedulerError::AllWorkersBusy) => {
tracing::trace!("all workers busy; waiting for more capacity"); tracing::trace!("all workers busy; waiting for more capacity");
match endpoints_rx.changed().await { tokio::time::sleep(Duration::from_millis(5)).await;
Ok(_) => {} continue;
Err(e) => {
tracing::error!("error waiting for endpoints change: {:?}", e);
break 'outer;
}
};
endpoints = endpoints_rx.borrow_and_update().clone();
} }
Err(e) => { Err(e) => {
tracing::error!("error scheduling request: {:?}", e); tracing::error!("error scheduling request: {:?}", e);
...@@ -173,59 +195,81 @@ impl KvScheduler { ...@@ -173,59 +195,81 @@ impl KvScheduler {
tracing::trace!("background endpoint subscriber shutting down"); tracing::trace!("background endpoint subscriber shutting down");
}); });
Ok(KvScheduler { request_tx }) Ok(KvScheduler {
request_tx,
sequences,
})
} }
pub async fn schedule( pub async fn schedule(
&self, &self,
overlap: OverlapScores, request_id: String,
isl_tokens: usize, isl_tokens: usize,
block_size: u32,
tokens: &[u32],
overlap: OverlapScores,
) -> Result<i64, KvSchedulerError> { ) -> Result<i64, KvSchedulerError> {
let mut sequences = self.sequences.lock().await;
let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
let potential_blocks = sequences.potential_blocks(token_sequence);
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest { let request = SchedulingRequest {
isl_tokens, isl_tokens,
overlap, overlap,
potential_blocks,
resp_tx, resp_tx,
}; };
self.request_tx self.request_tx
.send(request) .send(request)
.await .await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?; .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
let res = resp_rx let response = resp_rx
.await .await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?; .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
Ok(res)
if let Some(new_worker_ids) = response.endpoints_changed {
sequences.update_workers(new_worker_ids);
}
let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
sequences.add_request(request_id, token_sequence, response.best_worker_id);
Ok(response.best_worker_id)
}
/// Find the potential blocks for each worker if the sequence were routed there.
///
/// Takes the shared sequence-tracker lock for the duration of the call and returns
/// the tracker's answer unchanged: a map from worker ID to block count.
pub async fn potential_blocks(
    &self,
    token_sequence: TokenBlockSequence,
) -> HashMap<i64, usize> {
    self.sequences
        .lock()
        .await
        .potential_blocks(token_sequence)
}
/// Add a new request with its initial tokens to a specific worker
pub async fn add_request(
&self,
request_id: String,
token_sequence: TokenBlockSequence,
worker_id: WorkerId,
) {
let mut sequences = self.sequences.lock().await;
sequences.add_request(request_id, token_sequence, worker_id)
} }
}
// This becomes the driver function that handles the selection result /// Push a token to a specific request's sequence
pub fn process_worker_selection( pub async fn push(&self, request_id: &String, token: u32) {
workers: &mut ProcessedEndpoints, let mut sequences = self.sequences.lock().await;
selection: WorkerSelectionResult, sequences.push(request_id, token)
event_tx: &tokio::sync::mpsc::UnboundedSender<KVHitRateEvent>,
) -> i64 {
let worker = workers
.endpoints
.get_mut(&selection.worker_id)
.expect("worker not found");
// Update worker state predictively
// Will be overwritten on next polling of metrics
worker.data.kv_stats.kv_active_blocks += selection
.required_blocks
.saturating_sub(selection.overlap_blocks as u64);
// Emit event
if let Err(e) = event_tx.send(KVHitRateEvent {
worker_id: selection.worker_id,
isl_blocks: selection.required_blocks as usize,
overlap_blocks: selection.overlap_blocks,
}) {
tracing::warn!("Failed to send KV hit rate event: {:?}", e);
} }
selection.worker_id /// Free all blocks associated with a request
pub async fn free(&self, request_id: &String) {
let mut sequences = self.sequences.lock().await;
sequences.free(request_id)
}
} }
// Helper function for softmax sampling // Helper function for softmax sampling
...@@ -234,6 +278,24 @@ fn softmax_sample(logits: &HashMap<i64, f64>, temperature: f64) -> i64 { ...@@ -234,6 +278,24 @@ fn softmax_sample(logits: &HashMap<i64, f64>, temperature: f64) -> i64 {
panic!("Empty logits for softmax sampling"); panic!("Empty logits for softmax sampling");
} }
// Guard: if temperature is 0, return the key with the smallest logit value
if temperature == 0.0 {
// Find the minimum logit value
let min_logit = logits.values().fold(f64::INFINITY, |a, &b| a.min(b));
// Collect all keys with the minimum logit value (to handle ties)
let min_keys: Vec<_> = logits
.iter()
.filter(|(_, &v)| v == min_logit)
.map(|(k, _)| *k)
.collect();
// Randomly select from the minimum keys (handles single key case naturally)
let mut rng = rand::rng();
let index = rng.random_range(0..min_keys.len());
return min_keys[index];
}
let keys: Vec<_> = logits.keys().copied().collect(); let keys: Vec<_> = logits.keys().copied().collect();
let values: Vec<_> = logits.values().copied().collect(); let values: Vec<_> = logits.values().copied().collect();
...@@ -309,57 +371,42 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -309,57 +371,42 @@ impl WorkerSelector for DefaultWorkerSelector {
} }
let request_blocks = request.isl_tokens.div_ceil(block_size as usize); let request_blocks = request.isl_tokens.div_ceil(block_size as usize);
let potential_active_blocks = &request.potential_blocks;
let mut worker_logits = HashMap::new(); let mut worker_logits = HashMap::new();
let mut max_logit = f64::NEG_INFINITY;
// Calculate logits for each worker // Calculate logits for each worker
for (worker_id, ep) in workers.endpoints.iter() { for (worker_id, _) in workers.endpoints.iter() {
let worker_id = *worker_id; let cached_blocks = request.overlap.scores.get(worker_id).copied().unwrap_or(0) as f64;
let prefill_blocks = request_blocks as f64 - cached_blocks;
// Get overlap blocks for this worker
let overlap_blocks =
request.overlap.scores.get(&worker_id).copied().unwrap_or(0) as f64;
let new_blocks = request_blocks as f64 - overlap_blocks;
let kv_total_blocks = ep.data.kv_stats.kv_total_blocks as f64;
assert!(kv_total_blocks > 0.0);
let normalized_new_blocks = new_blocks / kv_total_blocks; // this is the number of blocks each worker would have if the request were scheduled there
let gpu_cache_usage = ep.data.kv_stats.gpu_cache_usage_perc as f64; let potential_blocks = *potential_active_blocks.get(worker_id).unwrap_or_else(||
let num_requests_waiting = ep.data.worker_stats.num_requests_waiting as f64; {tracing::warn!("assuming 0 decoding blocks for {worker_id}, as the load metrics endpoint does not exist yet");
&0
}) as f64;
// Calculate logit (lower is better) // Calculate logit (lower is better)
let logit = self.kv_router_config.overlap_score_weight * normalized_new_blocks let logit =
+ self.kv_router_config.gpu_cache_usage_weight * gpu_cache_usage self.kv_router_config.overlap_score_weight * prefill_blocks + potential_blocks;
+ self.kv_router_config.waiting_requests_weight * num_requests_waiting; max_logit = max_logit.max(logit);
worker_logits.insert(worker_id, logit); worker_logits.insert(*worker_id, logit);
tracing::info!( tracing::info!(
"Formula for {worker_id}: {logit:.3} = {:.1} * {normalized_new_blocks:.3} + {:.1} * {gpu_cache_usage:.3} + {:.1} * {num_requests_waiting:.3}", "Formula for {worker_id}: {logit:.3} = {:.1} * {prefill_blocks:.3} + {potential_blocks:.3}",
self.kv_router_config.overlap_score_weight, self.kv_router_config.overlap_score_weight,
self.kv_router_config.gpu_cache_usage_weight,
self.kv_router_config.waiting_requests_weight,
); );
} }
// Return early if no valid workers found // Normalize by dividing by max value
if worker_logits.is_empty() || worker_logits.values().all(|&v| v == 0.0) { for logit in worker_logits.values_mut() {
tracing::warn!("All worker logits are zero. Fallback to random routing."); *logit /= max_logit;
// Pick random worker
let mut rng = rand::rng();
let worker_ids: Vec<_> = workers.endpoints.keys().copied().collect();
let worker_id = worker_ids[rng.random_range(0..worker_ids.len())];
let overlap_blocks =
request.overlap.scores.get(&worker_id).copied().unwrap_or(0) as usize;
return Ok(WorkerSelectionResult {
worker_id,
required_blocks: request_blocks as u64,
overlap_blocks,
});
} }
// Use softmax sampling to select worker // Use softmax sampling to select worker
let temperature = 1.0; // You can make this configurable if needed let temperature = self.kv_router_config.router_temperature;
let best_worker_id = softmax_sample(&worker_logits, temperature); let best_worker_id = softmax_sample(&worker_logits, temperature);
let overlap_blocks = request let overlap_blocks = request
...@@ -371,7 +418,7 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -371,7 +418,7 @@ impl WorkerSelector for DefaultWorkerSelector {
let best_logit = worker_logits[&best_worker_id]; let best_logit = worker_logits[&best_worker_id];
tracing::info!( tracing::info!(
"Selected worker: {}, logit: {:.3}", "Selected worker: {}, normalized logit: {:.3}",
best_worker_id, best_worker_id,
best_logit best_logit
); );
...@@ -387,7 +434,6 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -387,7 +434,6 @@ impl WorkerSelector for DefaultWorkerSelector {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::kv_router::protocols::{KvStats, WorkerStats};
#[test] #[test]
fn test_softmax_sample_single_key() { fn test_softmax_sample_single_key() {
...@@ -416,191 +462,31 @@ mod tests { ...@@ -416,191 +462,31 @@ mod tests {
assert_eq!(softmax_sample(&logits, 1.0), worker_id); assert_eq!(softmax_sample(&logits, 1.0), worker_id);
} }
// Helper to create a worker endpoint #[test]
fn create_endpoint( fn test_softmax_sample_zero_temperature() {
worker_id: i64, // Test that with temperature 0, softmax_sample returns the key with smallest logit
gpu_cache_usage_perc: f32, let mut logits = HashMap::new();
num_requests_waiting: u64, logits.insert(1, 5.0);
) -> Endpoint { logits.insert(2, 3.0); // This has the smallest logit
Endpoint { logits.insert(3, 7.0);
name: format!("worker-{}", worker_id), logits.insert(4, 3.5);
subject: format!("worker-subject-{:x}", worker_id),
data: ForwardPassMetrics { // With temperature 0, should always return worker 2 (smallest logit)
kv_stats: KvStats { for _ in 0..10 {
gpu_cache_usage_perc, let result = softmax_sample(&logits, 0.0);
..Default::default() assert_eq!(
}, result, 2,
worker_stats: WorkerStats { "Should return worker with smallest logit when temperature is 0"
num_requests_waiting,
..Default::default()
},
// Other fields can be default initialized for this test
..Default::default()
},
}
}
// Helper to create ProcessedEndpoints
struct WorkerInfo {
id: i64,
usage: f32,
waiting: u64,
}
fn create_workers(workers: Vec<WorkerInfo>) -> ProcessedEndpoints {
let mut endpoints = HashMap::new();
for worker in workers {
endpoints.insert(
worker.id,
create_endpoint(worker.id, worker.usage, worker.waiting),
); );
} }
ProcessedEndpoints {
endpoints,
load_avg: 0.0,
load_std: 0.0,
}
}
// Helper to create a scheduling request // Test with negative values
struct WorkerOverlap { logits.clear();
worker_id: i64, logits.insert(10, -1.0);
overlap_blocks: u32, logits.insert(20, -5.0); // This has the smallest logit
} logits.insert(30, 0.0);
fn create_request(overlaps: Vec<WorkerOverlap>, isl_tokens: usize) -> SchedulingRequest {
SchedulingRequest {
isl_tokens,
overlap: OverlapScores {
scores: overlaps
.into_iter()
.map(|wo| (wo.worker_id, wo.overlap_blocks))
.collect(),
frequencies: vec![],
},
resp_tx: tokio::sync::oneshot::channel().0,
}
}
#[test] let result = softmax_sample(&logits, 0.0);
fn test_no_endpoints() { assert_eq!(result, 20, "Should handle negative logits correctly");
let workers = create_workers(vec![]);
let request = create_request(vec![], 100);
let selector = DefaultWorkerSelector::new(None);
let block_size = 20;
match selector.select_worker(&workers, &request, block_size) {
Err(KvSchedulerError::NoEndpoints) => {} // Expected
_ => panic!("Should return NoEndpoints error"),
}
} }
// #[test]
// fn test_select_worker_basic() {
// // Setup workers
// let workers = create_workers(vec![
// WorkerInfo {
// id: 1,
// usage: 0.50,
// waiting: 1,
// },
// WorkerInfo {
// id: 2,
// usage: 0.80,
// waiting: 0,
// },
// ]);
// // Setup request: 100 tokens, block_size=20 (5 blocks)
// let request = create_request(
// vec![
// WorkerOverlap {
// worker_id: 1,
// overlap_blocks: 3,
// },
// WorkerOverlap {
// worker_id: 2,
// overlap_blocks: 4,
// },
// ],
// 100,
// );
// let selector = DefaultWorkerSelector::new(None);
// let block_size = 20;
// // Execute selection
// let result = selector
// .select_worker(&workers, &request, block_size)
// .expect("Should select a worker");
// // Worker 2 should win because:
// // Worker1: 2.0 * 0.600 - 1.0 * 0.500 - 1.0 * 1.000 = -0.3
// // Worker2: 2.0 * 0.800 - 1.0 * 0.800 - 1.0 * 0.000 = 0.8
// assert_eq!(result.worker_id, 2);
// assert_eq!(result.required_blocks, 5); // 100 tokens / 20 block_size
// assert_eq!(result.overlap_blocks, 4);
// }
// #[test]
// fn test_no_overlap_scores() {
// // Workers exist but request has no overlap scores
// let workers = create_workers(vec![WorkerInfo {
// id: 1,
// usage: 0.50,
// waiting: 1,
// }]);
// let request = create_request(vec![], 100); // No overlaps
// let selector = DefaultWorkerSelector::new(None);
// let block_size = 20;
// let result = selector
// .select_worker(&workers, &request, block_size)
// .expect("Should fallback to selecting worker");
// // Worker1 should be selected with 0 overlap
// assert_eq!(result.worker_id, 1);
// assert_eq!(result.overlap_blocks, 0);
// }
// #[test]
// fn test_custom_weights() {
// // Setup workers
// let workers = create_workers(vec![
// WorkerInfo {
// id: 1,
// usage: 0.50,
// waiting: 1,
// },
// WorkerInfo {
// id: 2,
// usage: 0.80,
// waiting: 0,
// },
// ]);
// // Custom config with high priority on GPU usage
// let config = KvRouterConfig {
// gpu_cache_usage_weight: 10.0, // Very high weight
// overlap_score_weight: 2.0, // just current defaults
// waiting_requests_weight: 1.0,
// };
// let selector = DefaultWorkerSelector::new(Some(config));
// let request = create_request(
// vec![
// WorkerOverlap {
// worker_id: 1,
// overlap_blocks: 3,
// },
// WorkerOverlap {
// worker_id: 2,
// overlap_blocks: 4,
// },
// ],
// 100,
// );
// let block_size = 20;
// let result = selector
// .select_worker(&workers, &request, block_size)
// .expect("Should select worker");
// assert_eq!(result.worker_id, 1);
// }
} }
...@@ -20,7 +20,7 @@ use std::collections::HashMap; ...@@ -20,7 +20,7 @@ use std::collections::HashMap;
use crate::kv_router::scheduler::Endpoint; use crate::kv_router::scheduler::Endpoint;
#[derive(Debug, Default, Serialize, Deserialize, Clone)] #[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)]
pub struct ProcessedEndpoints { pub struct ProcessedEndpoints {
pub endpoints: HashMap<i64, Endpoint>, pub endpoints: HashMap<i64, Endpoint>,
pub load_avg: f64, pub load_avg: f64,
...@@ -32,7 +32,7 @@ impl ProcessedEndpoints { ...@@ -32,7 +32,7 @@ impl ProcessedEndpoints {
// compute some basic statistics // compute some basic statistics
let load_values: Vec<f64> = endpoints let load_values: Vec<f64> = endpoints
.iter() .iter()
.map(|x| x.data.kv_stats.kv_active_blocks as f64) .map(|endpoint| endpoint.data.kv_active_blocks() as f64)
.collect(); .collect();
let load_avg = load_values.iter().copied().sum::<f64>() / load_values.len() as f64; let load_avg = load_values.iter().copied().sum::<f64>() / load_values.len() as f64;
let variance = load_values let variance = load_values
...@@ -50,4 +50,15 @@ impl ProcessedEndpoints { ...@@ -50,4 +50,15 @@ impl ProcessedEndpoints {
load_std, load_std,
} }
} }
pub fn worker_ids(&self) -> Vec<i64> {
self.endpoints.keys().copied().collect()
}
pub fn active_blocks(&self) -> HashMap<i64, usize> {
self.endpoints
.iter()
.map(|(&worker_id, endpoint)| (worker_id, endpoint.data.kv_active_blocks() as usize))
.collect()
}
} }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! KV Cache Sequence Management for LLM Inference
//!
//! This module provides efficient management of token sequences and their associated KV cache blocks
//! for distributed LLM inference. It implements a shared block system where multiple requests can
//! reuse the same KV cache blocks for common token prefixes, significantly reducing memory usage.
//!
//! # Key Components
//!
//! - [`ActiveSequences`]: Single-threaded sequence manager that tracks active requests and their
//! token sequences, managing shared KV cache blocks efficiently.
//!
//! - [`ActiveSequencesMultiWorker`]: Multi-threaded extension that distributes sequence management
//! across multiple worker threads, enabling parallel processing of requests while maintaining
//! consistency.
//!
//! # Architecture
//!
//! The system uses a block-based approach where token sequences are divided into fixed-size blocks.
//! Each block is identified by a hash of its contents, allowing for deduplication when multiple
//! requests share common prefixes (e.g., system prompts, few-shot examples).
use crate::kv_router::indexer::WorkerId;
use crate::tokens::blocks::UniqueBlock;
use crate::tokens::TokenBlockSequence;
use derive_getters::Getters;
use std::collections::{HashMap, HashSet};
use std::sync::{mpsc, Arc};
use std::thread;
use std::time::Duration;
use uuid;
// TODO: use the common request_id if it exists in the repo
pub type RequestId = String;
/// Translate a [`TokenBlockSequence`] into the [`UniqueBlock`]s that represent it.
///
/// Every complete block of `tokens` becomes a [`UniqueBlock::FullBlock`] keyed by that block's
/// sequence hash. When the total token count is not a multiple of `block_size`, one extra
/// trailing entry is appended for the incomplete tail: a [`UniqueBlock::PartialBlock`] carrying
/// `uuid` when one is supplied, otherwise whatever `UniqueBlock::default()` produces.
fn create_unique_blocks_from_sequence(
    tokens: &TokenBlockSequence,
    uuid: Option<uuid::Uuid>,
    block_size: usize,
) -> Vec<UniqueBlock> {
    let mut result: Vec<UniqueBlock> = Vec::new();
    result.extend(
        tokens
            .blocks()
            .iter()
            .map(|full| UniqueBlock::FullBlock(full.sequence_hash())),
    );

    // A trailing partial block exists only when the tokens do not fill a whole number of blocks.
    let has_trailing_partial = tokens.total_tokens() % block_size != 0;
    if has_trailing_partial {
        result.push(uuid.map(UniqueBlock::PartialBlock).unwrap_or_default());
    }

    result
}
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache
///
/// Blocks are tracked by identity, so a block referenced by several requests contributes a single
/// entry to `unique_blocks` and is counted once in `active_blocks`.
#[derive(Debug, Getters)]
pub struct ActiveSequences {
/// Token sequence of each tracked request, keyed by request id.
active_seqs: HashMap<RequestId, TokenBlockSequence>,
/// The partial block most recently registered for each request, if any.
partial_blocks: HashMap<RequestId, UniqueBlock>,
/// For every tracked block, the set of requests that currently reference it.
unique_blocks: HashMap<UniqueBlock, HashSet<RequestId>>,
/// Block size this tracker was constructed with; used when splitting sequences into blocks.
#[getter(copy)]
block_size: usize,
/// Running count of tracked blocks: incremented when a block gains its first request and
/// decremented when it loses its last one.
#[getter(copy)]
active_blocks: usize,
}
impl ActiveSequences {
/// Build an empty [`ActiveSequences`] tracker for blocks of `block_size` tokens.
///
/// # Panics
///
/// Panics if `block_size` is not greater than 1.
pub fn new(block_size: usize) -> Self {
    // TODO: make this not a hard req
    assert!(block_size > 1, "block_size must be greater than 1");

    // Nothing is tracked yet: every map starts empty and the block counter at zero.
    let active_seqs = HashMap::new();
    let partial_blocks = HashMap::new();
    let unique_blocks = HashMap::new();

    Self {
        active_seqs,
        partial_blocks,
        unique_blocks,
        block_size,
        active_blocks: 0,
    }
}
fn add_block(&mut self, request_id: RequestId, block: &UniqueBlock) {
let is_new_block = !self.unique_blocks.contains_key(block);
self.unique_blocks
.entry(block.clone())
.or_default()
.insert(request_id.clone());
if is_new_block {
self.active_blocks += 1;
}
if matches!(block, UniqueBlock::PartialBlock(_)) {
self.partial_blocks.insert(request_id, block.clone());
};
}
/// Drop `request_id` from the set of requests using `block`.
///
/// When the last request is removed, the block itself is dropped from `unique_blocks` and
/// `active_blocks` is decremented. Removing a request that was not in the block's set leaves the
/// set unchanged.
///
/// # Panics
///
/// Panics if `block` is not currently tracked in `unique_blocks`.
fn remove_block(&mut self, request_id: &RequestId, block: &UniqueBlock) {
    let Some(request_ids) = self.unique_blocks.get_mut(block) else {
        panic!("Cannot remove a block that does not exist.")
    };

    // Hash-based removal of the single key; no need to scan the whole set with `retain`.
    request_ids.remove(request_id);

    // Remove the unique block if no more requests using it
    if request_ids.is_empty() {
        self.active_blocks -= 1;
        self.unique_blocks.remove(block);
    }
}
/// Start tracking `request_id` with `token_sequence` as its initial tokens.
///
/// Every block derived from the sequence is registered for the request, then the sequence is
/// stored under the request id. Returns the number of active blocks after the insertion.
pub fn add_request(
    &mut self,
    request_id: RequestId,
    token_sequence: TokenBlockSequence,
) -> usize {
    let initial_blocks =
        create_unique_blocks_from_sequence(&token_sequence, None, self.block_size);
    for block in initial_blocks.iter() {
        self.add_block(request_id.clone(), block);
    }

    self.active_seqs.insert(request_id, token_sequence);
    self.active_blocks
}
/// Count how many blocks of `token_sequence` are not tracked yet.
///
/// The sequence is split into blocks exactly as `add_request` would split it; the result is the
/// number of those blocks absent from `unique_blocks`. Nothing is modified.
pub fn new_blocks(&self, token_sequence: &TokenBlockSequence) -> usize {
    let mut untracked = 0;
    for candidate in create_unique_blocks_from_sequence(token_sequence, None, self.block_size) {
        if !self.unique_blocks.contains_key(&candidate) {
            untracked += 1;
        }
    }
    untracked
}
/// Block count this tracker would report if `token_sequence` were added.
///
/// Computed as the current `active_blocks` plus the blocks of the sequence that are not tracked
/// yet (see `new_blocks`). Nothing is modified.
pub fn potential_blocks(&self, token_sequence: &TokenBlockSequence) -> usize {
    let additional = self.new_blocks(token_sequence);
    additional + self.active_blocks
}
/// Free all blocks associated with a request
///
/// Releases the request's hold on every full block of its sequence and on its recorded partial
/// block, then forgets the request. Returns the number of active blocks afterwards. If the
/// request is unknown, a warning is logged and `0` is returned.
///
/// # Panics
///
/// Panics (via `remove_block`) if one of the blocks to release is not tracked.
pub fn free(&mut self, request_id: &RequestId) -> usize {
    // Take ownership of the sequence up front: this both detects an unknown request and
    // removes the entry, so no second lookup / `unwrap` is needed at the end.
    let Some(token_seq) = self.active_seqs.remove(request_id) else {
        tracing::warn!("Trying to free non-existent request {request_id}");
        return 0;
    };

    // NOTE(review): `push` registers a completed block only once the *next* token arrives.
    // If `TokenBlockSequence::append` reports a block as complete as soon as it fills up,
    // freeing a request right after such a push would reach `remove_block` with an
    // unregistered full block and panic — TODO confirm against `append`'s semantics.
    for block in create_unique_blocks_from_sequence(&token_seq, None, self.block_size) {
        // The partial block generated here is a fresh placeholder; the one actually
        // registered for the request is released through `partial_blocks` below.
        if matches!(block, UniqueBlock::FullBlock(_)) {
            self.remove_block(request_id, &block);
        }
    }

    if let Some(partial_block) = self.partial_blocks.remove(request_id) {
        self.remove_block(request_id, &partial_block);
    }

    self.active_blocks
}
/// Push a token to a specific request's sequence
///
/// Appends `token` to the request's stored sequence and, when the append starts a new block
/// (token count modulo `block_size` equals 1), updates the block bookkeeping. Returns the number
/// of active blocks after the push.
///
/// # Panics
///
/// Panics if `request_id` is not tracked, if appending the token fails, or (via `remove_block`)
/// if the request's recorded partial block is no longer tracked.
pub fn push(&mut self, request_id: &RequestId, token: u32) -> usize {
let token_seq = self
.active_seqs
.get_mut(request_id)
.expect("Request ID not found for token push");
token_seq.append(token).expect("Token push failed.");
// No need to update anything unless this token is the first one of a new block:
// in every other case the set of blocks used by the request is left as it was.
if token_seq.total_tokens() % self.block_size != 1 {
return self.active_blocks;
}
// Capture the hash now; `token_seq` borrows `self.active_seqs` and cannot be held across
// the `&mut self` calls below.
let last_seq_hash = token_seq
.last_complete_block()
.map(|block| block.sequence_hash());
// Promote a partial block into a full block if not already:
// first release the partial block recorded for this request (if any) ...
if let Some(partial_block) = self.partial_blocks.get(request_id).cloned() {
self.remove_block(request_id, &partial_block);
}
// ... then register the last complete block under its sequence hash ...
if let Some(full_block) = last_seq_hash {
self.add_block(request_id.clone(), &UniqueBlock::FullBlock(full_block));
}
// ... and finally register a fresh block for the newly started tail.
// NOTE(review): presumably `UniqueBlock::default()` is a `PartialBlock`, in which case
// `add_block` also overwrites this request's `partial_blocks` entry — confirm.
self.add_block(request_id.clone(), &UniqueBlock::default());
self.active_blocks
}
}
/// Commands sent from [`ActiveSequencesMultiWorker`] to a worker thread.
///
/// The query variants carry a `resp_tx` on which the worker thread sends back its answer.
#[derive(Debug)]
enum UpdateSequences {
/// Start tracking `request_id` with the given token sequence.
AddRequest {
request_id: RequestId,
token_sequence: TokenBlockSequence,
},
/// Release everything held by `request_id`.
Free {
request_id: RequestId,
},
/// Append one token to the sequence of `request_id`.
Push {
request_id: RequestId,
token: u32,
},
/// Ask how many blocks of `token_sequence` the worker does not track yet.
NewBlocks {
token_sequence: Arc<TokenBlockSequence>,
resp_tx: mpsc::SyncSender<usize>,
},
/// Ask for the worker's block count if `token_sequence` were added (new + active).
PotentialBlocks {
token_sequence: Arc<TokenBlockSequence>,
resp_tx: mpsc::SyncSender<usize>,
},
/// Ask for the worker's current number of active blocks.
ActiveBlocks {
resp_tx: mpsc::SyncSender<usize>,
},
/// Tell the worker thread to leave its command loop.
Shutdown,
}
/// Multi-worker extension of ActiveSequences that distributes requests across multiple threads
///
/// Each worker id is served by its own thread, which owns a private `ActiveSequences` and is
/// driven through a channel of `UpdateSequences` commands.
pub struct ActiveSequencesMultiWorker {
/// Command channel to each worker's thread.
senders: HashMap<WorkerId, mpsc::Sender<UpdateSequences>>,
/// Worker each request was assigned to by `add_request`; used to route `push` and `free`.
request_to_worker: HashMap<RequestId, WorkerId>,
/// Join handle of each worker's thread.
handles: HashMap<WorkerId, thread::JoinHandle<()>>,
/// Block size handed to the `ActiveSequences` of every worker thread.
block_size: usize,
}
impl ActiveSequencesMultiWorker {
/// Spawn one tracking thread per entry of `worker_ids`, each using `block_size`.
///
/// # Panics
///
/// Panics if `block_size` is not greater than 1.
pub fn new(block_size: usize, worker_ids: Vec<WorkerId>) -> Self {
    assert!(block_size > 1, "block_size must be greater than 1");

    // Start a thread per worker and split the (sender, handle) pairs into the two maps.
    let (senders, handles): (HashMap<_, _>, HashMap<_, _>) = worker_ids
        .into_iter()
        .map(|worker_id| {
            let (sender, handle) = Self::start_worker(block_size);
            ((worker_id, sender), (worker_id, handle))
        })
        .unzip();

    Self {
        senders,
        request_to_worker: HashMap::new(),
        handles,
        block_size,
    }
}
/// Spawn a worker thread and return the channel used to command it plus its join handle.
///
/// The thread owns a fresh `ActiveSequences` and applies commands in arrival order until it
/// receives `Shutdown` or every sender has been dropped. Failures to deliver a query answer
/// (the requester went away) are ignored.
fn start_worker(block_size: usize) -> (mpsc::Sender<UpdateSequences>, thread::JoinHandle<()>) {
    let (command_tx, command_rx) = mpsc::channel::<UpdateSequences>();

    let worker_thread = thread::spawn(move || {
        let mut state = ActiveSequences::new(block_size);

        // Iterating the receiver blocks for the next command and ends once the channel hangs up.
        for command in command_rx {
            match command {
                UpdateSequences::Shutdown => break,
                UpdateSequences::AddRequest {
                    request_id,
                    token_sequence,
                } => {
                    state.add_request(request_id, token_sequence);
                }
                UpdateSequences::Free { request_id } => {
                    state.free(&request_id);
                }
                UpdateSequences::Push { request_id, token } => {
                    state.push(&request_id, token);
                }
                UpdateSequences::NewBlocks {
                    token_sequence,
                    resp_tx,
                } => {
                    let _ = resp_tx.send(state.new_blocks(&token_sequence));
                }
                UpdateSequences::PotentialBlocks {
                    token_sequence,
                    resp_tx,
                } => {
                    let _ = resp_tx.send(state.potential_blocks(&token_sequence));
                }
                UpdateSequences::ActiveBlocks { resp_tx } => {
                    let _ = resp_tx.send(state.active_blocks());
                }
            }
        }
    });

    (command_tx, worker_thread)
}
/// Update the set of workers, adding and removing as needed
///
/// Workers present now but absent from `new_worker_ids` are shut down and joined; ids that are
/// new get a freshly started worker thread. Requests that were assigned to a removed worker are
/// forgotten, since the state they lived in is gone with the thread.
///
/// Returns the active block count of every worker after the update.
pub fn update_workers(&mut self, new_worker_ids: Vec<WorkerId>) -> HashMap<WorkerId, usize> {
    let current_workers: HashSet<WorkerId> = self.senders.keys().copied().collect();
    let new_workers: HashSet<WorkerId> = new_worker_ids.into_iter().collect();

    let workers_to_remove: HashSet<WorkerId> =
        current_workers.difference(&new_workers).copied().collect();
    let workers_to_add: Vec<WorkerId> =
        new_workers.difference(&current_workers).copied().collect();

    // Remove workers
    for worker_id in &workers_to_remove {
        tracing::warn!("Removing worker {}", worker_id);

        // Send shutdown command to the worker
        if let Some(sender) = self.senders.remove(worker_id) {
            let _ = sender.send(UpdateSequences::Shutdown);
        }
        if let Some(handle) = self.handles.remove(worker_id) {
            let _ = handle.join();
        }
    }

    // Drop routing entries that point at removed workers. Leaving them behind would keep
    // growing the map and, if the same worker id is later re-added, would route `push`/`free`
    // for those old requests to a fresh thread that has never seen them.
    if !workers_to_remove.is_empty() {
        self.request_to_worker
            .retain(|_, assigned| !workers_to_remove.contains(assigned));
    }

    // Add new workers
    for worker_id in &workers_to_add {
        tracing::warn!("Adding worker {}", worker_id);

        let (sender, handle) = Self::start_worker(self.block_size);
        self.senders.insert(*worker_id, sender);
        self.handles.insert(*worker_id, handle);
    }

    // Return active blocks for all workers
    self.active_blocks()
}
/// Assign `request_id` to `worker_id` and hand its token sequence to that worker's thread.
///
/// # Panics
///
/// Panics if `worker_id` is not a known worker, or if the command cannot be sent to the
/// worker's thread.
pub fn add_request(
    &mut self,
    request_id: RequestId,
    token_sequence: TokenBlockSequence,
    worker_id: WorkerId,
) {
    let Some(sender) = self.senders.get(&worker_id) else {
        panic!("Worker ID {worker_id} not found");
    };

    // Remember the routing before dispatching so later `push`/`free` calls find the worker.
    self.request_to_worker.insert(request_id.clone(), worker_id);

    let command = UpdateSequences::AddRequest {
        request_id,
        token_sequence,
    };
    sender
        .send(command)
        .expect("Failed to send add_request command to worker");
}
pub fn free(&mut self, request_id: &RequestId) {
let worker_id = self
.request_to_worker
.get(request_id)
.copied()
.expect("Request ID not found in request_to_worker mapping");
self.senders[&worker_id]
.send(UpdateSequences::Free {
request_id: request_id.clone(),
})
.expect("Failed to send free command to worker");
self.request_to_worker.remove(request_id);
}
pub fn push(&mut self, request_id: &RequestId, token: u32) {
let worker_id = self
.request_to_worker
.get(request_id)
.copied()
.expect("Request ID not found in request_to_worker mapping");
self.senders[&worker_id]
.send(UpdateSequences::Push {
request_id: request_id.clone(),
token,
})
.expect("Failed to send push command to worker");
}
/// Get the number of workers
///
/// Counted as the number of worker command channels currently held.
pub fn num_workers(&self) -> usize {
self.senders.len()
}
/// Generic method to query all workers with a given command
///
/// `command_fn` builds the command for one worker from the (shared) token sequence and the
/// sender on which that worker must answer. All commands are dispatched first, then the answers
/// are collected, so the workers compute concurrently.
///
/// Returns the answer of every worker, keyed by worker id.
///
/// # Panics
///
/// Panics if a command cannot be sent, or if a worker does not answer within one second.
fn query_workers(
    &self,
    token_sequence: Option<TokenBlockSequence>,
    command_fn: impl Fn(Option<Arc<TokenBlockSequence>>, mpsc::SyncSender<usize>) -> UpdateSequences,
) -> HashMap<WorkerId, usize> {
    /// How long to wait for each individual worker's answer.
    const RESPONSE_TIMEOUT: Duration = Duration::from_secs(1);

    // One shared copy of the sequence; each worker only gets a cheap `Arc` clone.
    let token_sequence_shared = token_sequence.map(Arc::new);

    // Send queries to all workers in parallel
    let mut receivers = Vec::with_capacity(self.senders.len());
    for (&worker_id, sender) in &self.senders {
        // Capacity 1 (exactly one answer is ever sent): the worker can deposit its result
        // and move on instead of blocking until this thread reaches its receiver, as a
        // rendezvous channel would force it to.
        let (resp_tx, resp_rx) = mpsc::sync_channel(1);
        sender
            .send(command_fn(token_sequence_shared.clone(), resp_tx))
            .expect("Failed to send command to worker");
        receivers.push((worker_id, resp_rx));
    }

    // Collect results from all workers
    receivers
        .into_iter()
        .map(|(worker_id, receiver)| {
            let result = receiver
                .recv_timeout(RESPONSE_TIMEOUT)
                .expect("Failed to receive response from worker");
            (worker_id, result)
        })
        .collect()
}
/// Ask every worker how many blocks of `token_sequence` it does not track yet.
///
/// Returns the per-worker answers keyed by worker id.
pub fn new_blocks(&self, token_sequence: TokenBlockSequence) -> HashMap<WorkerId, usize> {
    self.query_workers(Some(token_sequence), |shared, resp_tx| {
        // `query_workers` passes back the sequence supplied just above, so it is always set.
        let Some(token_sequence) = shared else {
            unreachable!("token_sequence should always be Some for new_blocks")
        };
        UpdateSequences::NewBlocks {
            token_sequence,
            resp_tx,
        }
    })
}
/// Ask every worker for the block count it would have if `token_sequence` were added to it
/// (its new blocks plus its currently active blocks).
///
/// Returns the per-worker answers keyed by worker id.
pub fn potential_blocks(&self, token_sequence: TokenBlockSequence) -> HashMap<WorkerId, usize> {
    self.query_workers(Some(token_sequence), |shared, resp_tx| {
        // `query_workers` passes back the sequence supplied just above, so it is always set.
        let Some(token_sequence) = shared else {
            unreachable!("token_sequence should always be Some for potential_blocks")
        };
        UpdateSequences::PotentialBlocks {
            token_sequence,
            resp_tx,
        }
    })
}
/// Query all workers for their current number of active blocks
///
/// No token sequence is involved, so `None` is passed and the sequence argument of the command
/// builder is ignored. Returns the per-worker counts keyed by worker id.
pub fn active_blocks(&self) -> HashMap<WorkerId, usize> {
self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveBlocks { resp_tx })
}
}
impl Drop for ActiveSequencesMultiWorker {
    /// Stop every worker thread and wait for it to exit.
    ///
    /// Send and join failures are ignored: a worker that is already gone needs no shutdown.
    fn drop(&mut self) {
        // Send shutdown command to all workers
        self.senders.values().for_each(|sender| {
            let _ = sender.send(UpdateSequences::Shutdown);
        });

        // Wait for all threads to finish
        self.handles
            .drain()
            .map(|(_, handle)| handle)
            .for_each(|handle| {
                let _ = handle.join();
            });
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tokens::Tokens;
/// Walks a single `ActiveSequences` through add / push / free for two requests that share a
/// prefix and checks the block counters after every step.
#[test]
fn test_shared_sequence_manager_operations() {
let block_size = 4;
let mut manager = ActiveSequences::new(block_size);
// Helper: turn a plain token list into a block sequence with this test's block size.
let to_sequence =
|tokens: Vec<u32>| Tokens::from(tokens).into_sequence(block_size as u32, None);
// Step 1: Add request 0 with tokens [0, 1, 2], then push 3 and 4
manager.add_request("0".to_string(), to_sequence(vec![0, 1, 2]));
manager.push(&"0".to_string(), 3);
manager.push(&"0".to_string(), 4);
// Five tokens with block_size 4: expected to be one full block plus one partial block.
assert_eq!(manager.active_blocks(), 2);
assert_eq!(manager.partial_blocks.len(), 1);
// Step 2: Add request 1 with tokens [0, 1, 2, 3, 4, 5, 6]
manager.add_request("1".to_string(), to_sequence(vec![0, 1, 2, 3, 4, 5, 6]));
// Only one more block is expected: the full block [0, 1, 2, 3] is shared with request 0.
assert_eq!(manager.active_blocks(), 3);
// Check that only one key is FullBlock with both requests sharing it
let mut full_block_count = 0;
let mut shared_block_requests = None;
for (block, requests) in &manager.unique_blocks {
if let UniqueBlock::FullBlock(_) = block {
full_block_count += 1;
if requests.len() == 2 {
shared_block_requests = Some(requests.clone());
}
}
}
assert_eq!(full_block_count, 1);
assert!(shared_block_requests.is_some());
let shared_requests = shared_block_requests.unwrap();
assert!(shared_requests.contains("0"));
assert!(shared_requests.contains("1"));
// A hypothetical six-token request reuses the tracked full block; only its tail is new.
let new_blocks = manager.new_blocks(&to_sequence(vec![0, 1, 2, 3, 4, 5]));
assert_eq!(new_blocks, 1);
// Step 3: Free request 1
manager.free(&"1".to_string());
// Request 1's partial block goes away; the shared full block stays for request 0.
assert_eq!(manager.active_blocks(), 2);
// Step 4: Free request 0
manager.free(&"0".to_string());
// Nothing may be left behind once every request has been freed.
assert_eq!(manager.active_blocks(), 0);
assert_eq!(manager.unique_blocks.len(), 0);
assert_eq!(manager.partial_blocks.len(), 0);
assert_eq!(manager.active_seqs.len(), 0);
}
/// Routes requests to three workers, compares their `new_blocks` answers for the same sequence,
/// then removes one worker and adds another and queries again.
#[test]
fn test_active_sequences_multi_worker() {
let block_size = 4;
let worker_ids = vec![0, 1, 2];
let mut manager = ActiveSequencesMultiWorker::new(block_size, worker_ids);
// Helper: turn a plain token list into a block sequence with this test's block size.
let to_sequence =
|tokens: Vec<u32>| Tokens::from(tokens).into_sequence(block_size as u32, None);
// Send request [0, 1, 2, 3] to worker 0
manager.add_request("req0".to_string(), to_sequence(vec![0, 1, 2, 3]), 0);
// Send request [0, 1, 2] to worker 1, then push 3 and push 4
manager.add_request("req1".to_string(), to_sequence(vec![0, 1, 2]), 1);
manager.push(&"req1".to_string(), 3);
manager.push(&"req1".to_string(), 4);
// Send request [0, 1, 2] to worker 2
manager.add_request("req2".to_string(), to_sequence(vec![0, 1, 2]), 2);
// Check new_blocks on tokens [0, 1, 2, 3, 4]
// Workers 0 and 1 are expected to already track the full block [0, 1, 2, 3]; worker 2 only
// holds a partial block, so both blocks of the query count as new there.
let new_blocks_map = manager.new_blocks(to_sequence(vec![0, 1, 2, 3, 4]));
assert_eq!(new_blocks_map[&0], 1); // Worker 0 would have 1 new block
assert_eq!(new_blocks_map[&1], 1); // Worker 1 would have 1 new block
assert_eq!(new_blocks_map[&2], 2); // Worker 2 would have 2 new blocks
// Drop worker 2, then introduce worker 3.
manager.update_workers(vec![0, 1]);
manager.update_workers(vec![0, 1, 3]);
let new_blocks_map = manager.new_blocks(to_sequence(vec![0, 1, 2, 3, 4]));
// Three workers answer again, and the brand-new worker 3 tracks nothing yet.
assert_eq!(new_blocks_map.len(), 3);
assert_eq!(new_blocks_map[&3], 2);
}
}
...@@ -46,8 +46,9 @@ ...@@ -46,8 +46,9 @@
//! implementation of the main block manager. //! implementation of the main block manager.
use crate::mocker::evictor::LRUEvictor; use crate::mocker::evictor::LRUEvictor;
use crate::mocker::protocols::{MoveBlock, MoveBlockResponse, PrefillCost, UniqueBlock}; use crate::mocker::protocols::{MoveBlock, MoveBlockResponse, PrefillCost};
use crate::mocker::sequence::ActiveSequence; use crate::mocker::sequence::ActiveSequence;
use crate::tokens::blocks::UniqueBlock;
use derive_getters::Getters; use derive_getters::Getters;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use tokio::sync::mpsc; use tokio::sync::mpsc;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment