"tests/vscode:/vscode.git/clone" did not exist on "e7544f19679685d9ab50cbcf61691787e32041ac"
Unverified Commit 84e71e27 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: predictive active blocks for routing without load metrics (#1731)


Signed-off-by: default avatarYan Ru Pei <yanrpei@gmail.com>
Co-authored-by: default avatarAlec <35311602+alec-flowers@users.noreply.github.com>
parent ffccc722
......@@ -83,7 +83,7 @@ use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::time::Duration as StdDuration;
use dynamo_llm::kv_router::protocols::ForwardPassMetrics;
use dynamo_llm::kv_router::protocols::{ForwardPassMetrics, LoadMetrics};
use dynamo_llm::kv_router::scheduler::Endpoint;
use dynamo_llm::kv_router::scoring::ProcessedEndpoints;
......@@ -449,7 +449,10 @@ impl PrometheusMetrics {
// Update per-worker metrics
for (worker_id, endpoint) in processed.endpoints.iter() {
let worker_id = worker_id.to_string();
let metrics = endpoint.data.clone();
let load_metrics = endpoint.data.clone();
let LoadMetrics::EngineLoadMetrics(metrics) = load_metrics else {
panic!("Can only update with ForwardPassMetrics");
};
self.set_worker_gauge(
&self.kv_blocks_active,
......@@ -602,7 +605,7 @@ pub fn postprocess_metrics(
e.id().ok().map(|id| Endpoint {
name: format!("worker-{id}"),
subject: e.subject.clone(),
data: m.clone(),
data: LoadMetrics::EngineLoadMetrics(m.clone()),
})
})
.collect();
......
......@@ -66,7 +66,7 @@ async fn app(runtime: Runtime) -> Result<()> {
let selector = Box::new(CustomWorkerSelector::default());
let router = KvRouter::new(component.clone(), args.block_size, Some(selector)).await?;
let router = KvRouter::new(component.clone(), args.block_size, Some(selector), true).await?;
let router = Ingress::for_engine(Arc::new(router))?;
component
......
......@@ -3392,14 +3392,8 @@ index cafd8150b..6a5e45b4e 100644
+ num_requests_waiting: int
+ gpu_cache_usage_perc: float
+ gpu_prefix_cache_hit_rate: float
+ spec_decode_draft_acceptance_rate: Optional[float] = None
+ spec_decode_system_efficiency: Optional[float] = None
+ spec_decode_draft_tokens: Optional[int] = None
+ spec_decode_emitted_tokens: Optional[int] = None
+ spec_decode_accepted_tokens: Optional[int] = None
+ spec_decode_num_spec_tokens: Optional[int] = None
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index f058b1329..2fdb5b8bf 100644
index f058b1329..fd5610a3c 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -1,4 +1,17 @@
......@@ -3460,16 +3454,25 @@ index f058b1329..2fdb5b8bf 100644
from vllm.engine.protocol import EngineClient
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
@@ -48,6 +66,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -48,6 +66,17 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import Device, deprecate_kwargs
+from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest, RemotePrefillRequestCallback
+from vllm.distributed.device_communicators.nixl import NixlMetadata
+
+# Import ForwardPassMetrics and related classes from dynamo
+try:
+ from dynamo.llm import ForwardPassMetrics, WorkerStats, KvStats
+except ImportError:
+ # Fallback if dynamo imports are not available
+ ForwardPassMetrics = None
+ WorkerStats = None
+ KvStats = None
logger = init_logger(__name__)
@@ -93,6 +113,7 @@ class MQLLMEngineClient(EngineClient):
@@ -93,6 +122,7 @@ class MQLLMEngineClient(EngineClient):
self._errored_with: Optional[BaseException] = None
# Get the configs.
......@@ -3477,7 +3480,7 @@ index f058b1329..2fdb5b8bf 100644
self.model_config = engine_config.model_config
self.decoding_config = engine_config.decoding_config
@@ -117,6 +138,10 @@ class MQLLMEngineClient(EngineClient):
@@ -117,6 +147,10 @@ class MQLLMEngineClient(EngineClient):
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
......@@ -3488,7 +3491,7 @@ index f058b1329..2fdb5b8bf 100644
# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
@@ -131,8 +156,27 @@ class MQLLMEngineClient(EngineClient):
@@ -131,8 +165,27 @@ class MQLLMEngineClient(EngineClient):
# Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready.
self.health_loop: Optional[asyncio.Task] = None
......@@ -3516,7 +3519,7 @@ index f058b1329..2fdb5b8bf 100644
@staticmethod
def is_unsupported_config(vllm_config: VllmConfig):
# Pipeline parallel not yet supported
@@ -182,6 +226,61 @@ class MQLLMEngineClient(EngineClient):
@@ -182,6 +235,76 @@ class MQLLMEngineClient(EngineClient):
except Exception as e:
self._set_errored(e)
......@@ -3553,13 +3556,28 @@ index f058b1329..2fdb5b8bf 100644
+ if self.metrics_publisher is not None and isinstance(
+ metrics, KvMetrics
+ ):
+ self.metrics_publisher.publish(metrics.request_active_slots,
+ metrics.request_total_slots,
+ metrics.kv_active_blocks,
+ metrics.kv_total_blocks,
+ metrics.num_requests_waiting,
+ metrics.gpu_cache_usage_perc,
+ metrics.gpu_prefix_cache_hit_rate)
+ # Construct structured metrics objects
+ worker_stats = WorkerStats(
+ request_active_slots=metrics.request_active_slots,
+ request_total_slots=metrics.request_total_slots,
+ num_requests_waiting=metrics.num_requests_waiting,
+ data_parallel_rank=None
+ )
+
+ kv_stats = KvStats(
+ kv_active_blocks=metrics.kv_active_blocks,
+ kv_total_blocks=metrics.kv_total_blocks,
+ gpu_cache_usage_perc=metrics.gpu_cache_usage_perc,
+ gpu_prefix_cache_hit_rate=metrics.gpu_prefix_cache_hit_rate
+ )
+
+ forward_pass_metrics = ForwardPassMetrics(
+ worker_stats=worker_stats,
+ kv_stats=kv_stats,
+ spec_decode_stats=None
+ )
+
+ self.metrics_publisher.publish(forward_pass_metrics)
+ logger.debug("Metrics successful.")
+
+ # TODO: Investigate sending whole stats object
......@@ -3578,7 +3596,7 @@ index f058b1329..2fdb5b8bf 100644
async def run_output_handler_loop(self):
"""Get RequestOutputs from Engine and stream to Request Queues"""
@@ -250,7 +349,7 @@ class MQLLMEngineClient(EngineClient):
@@ -250,7 +373,7 @@ class MQLLMEngineClient(EngineClient):
# Put each output into the appropriate queue.
elif isinstance(
request_outputs,
......@@ -3587,7 +3605,7 @@ index f058b1329..2fdb5b8bf 100644
self._add_output(request_outputs)
else:
for request_output in request_outputs:
@@ -261,7 +360,7 @@ class MQLLMEngineClient(EngineClient):
@@ -261,7 +384,7 @@ class MQLLMEngineClient(EngineClient):
def _add_output(self, request_output: Union[RequestOutput,
RPCAdapterLoadedResponse,
......@@ -3596,7 +3614,7 @@ index f058b1329..2fdb5b8bf 100644
queue = self.output_queues.get(request_output.request_id)
if queue is not None:
queue.put_nowait(request_output)
@@ -283,12 +382,25 @@ class MQLLMEngineClient(EngineClient):
@@ -283,12 +406,25 @@ class MQLLMEngineClient(EngineClient):
# Wait until server is ready.
response = await self._wait_for_server_rpc(socket)
......@@ -3622,7 +3640,7 @@ index f058b1329..2fdb5b8bf 100644
def close(self):
"""Destroy the ZeroMQ Context."""
@@ -298,6 +410,8 @@ class MQLLMEngineClient(EngineClient):
@@ -298,6 +434,8 @@ class MQLLMEngineClient(EngineClient):
# Cancel background tasks.
if self.health_loop is not None:
self.health_loop.cancel()
......@@ -3631,7 +3649,7 @@ index f058b1329..2fdb5b8bf 100644
if self.output_loop is not None:
self.output_loop.cancel()
@@ -420,6 +534,9 @@ class MQLLMEngineClient(EngineClient):
@@ -420,6 +558,9 @@ class MQLLMEngineClient(EngineClient):
"""
if self._errored_with is not None:
raise self._errored_with
......@@ -3641,7 +3659,7 @@ index f058b1329..2fdb5b8bf 100644
@property
def is_running(self) -> bool:
@@ -478,6 +595,7 @@ class MQLLMEngineClient(EngineClient):
@@ -478,6 +619,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
......@@ -3649,7 +3667,7 @@ index f058b1329..2fdb5b8bf 100644
*,
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[RequestOutput, None]:
@@ -507,7 +625,8 @@ class MQLLMEngineClient(EngineClient):
@@ -507,7 +649,8 @@ class MQLLMEngineClient(EngineClient):
return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers,
......@@ -3659,7 +3677,7 @@ index f058b1329..2fdb5b8bf 100644
@overload
def encode(
@@ -591,6 +710,7 @@ class MQLLMEngineClient(EngineClient):
@@ -591,6 +734,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
......@@ -3667,7 +3685,7 @@ index f058b1329..2fdb5b8bf 100644
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
PoolingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
@@ -636,6 +756,12 @@ class MQLLMEngineClient(EngineClient):
@@ -636,6 +780,12 @@ class MQLLMEngineClient(EngineClient):
else:
lp_bytes = None
......@@ -3680,7 +3698,7 @@ index f058b1329..2fdb5b8bf 100644
request_bytes = pickle.dumps(
RPCProcessRequest(
prompt=prompt,
@@ -645,11 +771,11 @@ class MQLLMEngineClient(EngineClient):
@@ -645,11 +795,11 @@ class MQLLMEngineClient(EngineClient):
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
......@@ -3694,7 +3712,7 @@ index f058b1329..2fdb5b8bf 100644
await self.input_socket.send_multipart(parts, copy=False)
# 4) Stream the RequestOutputs from the output queue. Note
@@ -740,3 +866,22 @@ class MQLLMEngineClient(EngineClient):
@@ -740,3 +890,22 @@ class MQLLMEngineClient(EngineClient):
# Raise on error, otherwise happily return None
if isinstance(request_output, BaseException):
raise request_output
......
......@@ -15,6 +15,8 @@ See the License for the specific language governing permissions and
limitations under the License.
-->
>[!NOTE]
>This information is temporary and will change soon.
# KV Cache Routing
This documentation explains how Key-Value (KV) cache routing works in Dynamo, providing optimized inference for large language models by intelligently directing requests to workers with the most relevant cached data while simultaneously load balancing based on utilization metrics sent by the workers.
......
......@@ -8,7 +8,7 @@ It supports these engines: mistralrs, llamacpp, sglang, vllm, and tensorrt-llm.
Usage:
```
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=2.0] [--kv-gpu-cache-usage-weight=1.0] [--kv-waiting-requests-weight=1.0] [--verbosity (-v|-vv)]
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=1.0] [--router-temperature=0.5] [--verbosity (-v|-vv)]
```
Example: `dynamo run Qwen/Qwen3-0.6B`
......@@ -201,6 +201,8 @@ The only difference from the distributed system above is `--router-mode kv`. The
For performance testing, compare a typical workload with `--router-mode random|round-robin` to see if it can benefit from KV-aware routing.
The argument `--kv-overlap-score-weight` sets the amount weighting on overlaps with prefix caches, which directly contributes to the prefill cost, so a large weight is expected to yield a better TTFT (at the expense of worse ITL). When this is set 0, we do not consider the prefix caches at all (falling back to pure load balancing behavior on the active blocks), in which case we do not require the backend engines to emit any KV events. The argument `--router-temperature` sets the temperature when randomly selecting the workers to route to via softmax sampling on the router cost logits, setting it to 0 recovers the deterministic behavior where the min logit is picked.
## Full usage details
`dynamo run` executes `dynamo-run`. `dynamo-run` is also an example of what can be built in Rust with the `dynamo-llm` and `dynamo-runtime` crates. The following guide shows how to build from source with all the features.
......
......@@ -15,6 +15,9 @@ See the License for the specific language governing permissions and
limitations under the License.
-->
>[!NOTE]
>This information is temporary and will change soon.
# KV Router Performance Tuning
## Overview
......
......@@ -110,20 +110,23 @@ pub struct Flags {
#[arg(long, default_value = "round-robin")]
pub router_mode: RouterMode,
/// Maximum number of batched tokens for KV routing
/// Needed for informing the KV router
/// TODO: derive from vllm args
/// NOTE: this is not actually used for now
#[arg(long, default_value = "8192")]
pub max_num_batched_tokens: Option<u32>,
/// KV Router: Weight for overlap score in worker selection.
/// Higher values prioritize KV cache reuse. Default: 2.0
#[arg(long)]
pub kv_overlap_score_weight: Option<f64>,
/// KV Router: Weight for GPU cache usage in worker selection.
/// Higher values avoid workers with nearly full KV caches. Default: 1.0
#[arg(long)]
pub kv_gpu_cache_usage_weight: Option<f64>,
/// KV Router: Weight for waiting requests in worker selection.
/// Higher values avoid workers with queued requests. Default: 1.0
/// KV Router: Temperature for worker sampling via softmax.
/// Higher values promote more randomness, and 0 fallbacks to deterministic.
/// Default: 0.5
#[arg(long)]
pub kv_waiting_requests_weight: Option<f64>,
pub router_temperature: Option<f64>,
/// Max model context length. Reduce this if you don't have enough VRAM for the full model
/// context length (e.g. Llama 4).
......@@ -211,8 +214,8 @@ impl Flags {
self.router_mode.into(),
KvRouterConfig::new(
self.kv_overlap_score_weight,
self.kv_gpu_cache_usage_weight,
self.kv_waiting_requests_weight,
self.router_temperature,
self.max_num_batched_tokens,
),
)
}
......
......@@ -26,7 +26,14 @@ from vllm.entrypoints.openai.api_server import (
)
from vllm.inputs import TokensPrompt
from dynamo.llm import ModelType, WorkerMetricsPublisher, register_llm
from dynamo.llm import (
ForwardPassMetrics,
KvStats,
ModelType,
WorkerMetricsPublisher,
WorkerStats,
register_llm,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
......@@ -70,15 +77,29 @@ class RequestHandler:
self.engine_client.set_metrics_publisher(self.metrics_publisher)
# Initially send dummy metrics to kick start,
# vLLM will not update stat until forward pass is triggered
self.metrics_publisher.publish(
0, # request_active_slots
1024, # request_total_slots
0, # kv_active_blocks
1024, # kv_total_blocks
0, # num_requests_waiting
0.0, # gpu_cache_usage_perc
0.0, # gpu_prefix_cache_hit_rate
# Create the structured metrics objects
worker_stats = WorkerStats(
request_active_slots=0,
request_total_slots=1024,
num_requests_waiting=0,
data_parallel_rank=None,
)
kv_stats = KvStats(
kv_active_blocks=0,
kv_total_blocks=1024,
gpu_cache_usage_perc=0.0,
gpu_prefix_cache_hit_rate=0.0,
)
metrics = ForwardPassMetrics(
worker_stats=worker_stats, kv_stats=kv_stats, spec_decode_stats=None
)
# Publish the metrics as a single object
self.metrics_publisher.publish(metrics)
task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback(
lambda _: logging.debug("metrics publisher endpoint created")
......
......@@ -72,7 +72,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Client>()?;
m.add_class::<EtcdClient>()?;
m.add_class::<AsyncResponseStream>()?;
m.add_class::<llm::kv::KvRouter>()?;
m.add_class::<llm::disagg_router::DisaggregatedRouter>()?;
m.add_class::<llm::kv::WorkerMetricsPublisher>()?;
m.add_class::<llm::model_card::ModelDeploymentCard>()?;
......
......@@ -29,51 +29,6 @@ use tracing;
use llm_rs::kv_router::protocols::*;
use llm_rs::kv_router::publisher::{create_stored_blocks, KvEventSourceConfig};
#[pyclass]
pub(crate) struct KvRouter {
inner: Arc<llm_rs::kv_router::KvRouter>,
}
#[pymethods]
impl KvRouter {
#[new]
fn new(component: Component, kv_block_size: usize) -> PyResult<Self> {
if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
};
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let inner = llm_rs::kv_router::KvRouter::new(
component.inner.clone(),
kv_block_size as u32,
None,
)
.await
.map_err(to_pyerr)?;
Ok(Self {
inner: Arc::new(inner),
})
})
}
fn schedule<'p>(
&self,
py: Python<'p>,
token_ids: Vec<u32>,
lora_id: u64,
) -> PyResult<Bound<'p, PyAny>> {
let router = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let worker_id = router
.schedule(&token_ids, lora_id)
.await
.map_err(to_pyerr)?;
Ok(worker_id)
})
}
}
#[pyfunction]
pub fn compute_block_hash_for_seq_py(tokens: Vec<u32>, kv_block_size: usize) -> PyResult<Vec<u64>> {
if kv_block_size == 0 {
......@@ -617,25 +572,34 @@ impl KvMetricsAggregator {
fn get_metrics<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
// TODO: update EndpointKvMetrics to match the new ForwardPassMetrics struct
let endpoints = self.inner.get_endpoints();
let load_avg = endpoints.load_avg;
let load_std = endpoints.load_std;
let endpoint_kv_metrics = endpoints
.endpoints
.iter()
.map(|(worker_id, endpoint)| EndpointKvMetrics {
worker_id: *worker_id,
request_active_slots: endpoint.data.worker_stats.request_active_slots,
request_total_slots: endpoint.data.worker_stats.request_total_slots,
kv_active_blocks: endpoint.data.kv_stats.kv_active_blocks,
kv_total_blocks: endpoint.data.kv_stats.kv_total_blocks,
num_requests_waiting: endpoint.data.worker_stats.num_requests_waiting,
gpu_cache_usage_perc: endpoint.data.kv_stats.gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate: endpoint.data.kv_stats.gpu_prefix_cache_hit_rate,
.into_iter()
.map(|(worker_id, endpoint)| {
let metrics = endpoint.data;
let LoadMetrics::EngineLoadMetrics(fwd_pass_metrics) = metrics else {
panic!("Endpoints do not contain forward pass metrics.");
};
EndpointKvMetrics {
worker_id,
request_active_slots: fwd_pass_metrics.worker_stats.request_active_slots,
request_total_slots: fwd_pass_metrics.worker_stats.request_total_slots,
kv_active_blocks: fwd_pass_metrics.kv_stats.kv_active_blocks,
kv_total_blocks: fwd_pass_metrics.kv_stats.kv_total_blocks,
num_requests_waiting: fwd_pass_metrics.worker_stats.num_requests_waiting,
gpu_cache_usage_perc: fwd_pass_metrics.kv_stats.gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate: fwd_pass_metrics.kv_stats.gpu_prefix_cache_hit_rate,
}
})
.collect();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
Ok(AggregatedMetrics {
endpoints: endpoint_kv_metrics,
load_avg: endpoints.load_avg,
load_std: endpoints.load_std,
load_avg,
load_std,
})
})
}
......
......@@ -272,25 +272,6 @@ class Client:
"""
...
class KvRouter:
"""
A router will determine which worker should handle a given request.
"""
...
def __init__(self, drt: DistributedRuntime, component: Component) -> None:
"""
Create a `KvRouter` object that is associated with the `component`
"""
def schedule(self, token_ids: List[int], lora_id: int) -> int:
"""
Return the worker id that should handle the given token ids,
exception will be raised if there is no worker available.
"""
...
class DisaggregatedRouter:
"""
A router that determines whether to perform prefill locally or remotely based on
......
......@@ -32,7 +32,6 @@ from dynamo._core import KvEventPublisher as KvEventPublisher
from dynamo._core import KvIndexer as KvIndexer
from dynamo._core import KvMetricsAggregator as KvMetricsAggregator
from dynamo._core import KvRecorder as KvRecorder
from dynamo._core import KvRouter as KvRouter
from dynamo._core import KvStats as KvStats
from dynamo._core import ModelType as ModelType
from dynamo._core import OverlapScores as OverlapScores
......
......@@ -212,8 +212,20 @@ impl ModelManager {
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
) -> anyhow::Result<Arc<KvRouter>> {
// Determine if we should use KV events based on overlap score weight
let use_kv_events = kv_router_config
.as_ref()
.map(|config| config.overlap_score_weight > 0.0)
.unwrap_or(false);
let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
let chooser = KvRouter::new(component.clone(), kv_cache_block_size, Some(selector)).await?;
let chooser = KvRouter::new(
component.clone(),
kv_cache_block_size,
Some(selector),
use_kv_events,
)
.await?;
let new_kv_chooser = Arc::new(chooser);
self.kv_choosers
.lock()
......
......@@ -23,11 +23,12 @@ pub mod publisher;
pub mod recorder;
pub mod scheduler;
pub mod scoring;
pub mod sequence;
use crate::{
kv_router::{
indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
metrics_aggregator::KvMetricsAggregator,
metrics_aggregator::EndpointCollector,
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
scoring::ProcessedEndpoints,
......@@ -58,25 +59,20 @@ pub trait WorkerSelector {
/// KV Router configuration parameters
#[derive(Debug, Clone)]
pub struct KvRouterConfig {
/// Weight for overlap score in worker selection.
/// Higher values prioritize KV cache reuse. Default: 2.0
pub overlap_score_weight: f64,
/// Weight for GPU cache usage in worker selection.
/// Higher values avoid workers with nearly full KV caches. Default: 1.0
pub gpu_cache_usage_weight: f64,
pub router_temperature: f64,
/// Weight for waiting requests in worker selection.
/// Higher values avoid workers with queued requests. Default: 1.0
pub waiting_requests_weight: f64,
// note: this is not actually used for now
pub max_num_batched_tokens: u32,
}
impl Default for KvRouterConfig {
fn default() -> Self {
Self {
overlap_score_weight: 1.0,
gpu_cache_usage_weight: 1.0,
waiting_requests_weight: 1.0,
router_temperature: 0.5,
max_num_batched_tokens: 8192,
}
}
}
......@@ -86,16 +82,15 @@ impl KvRouterConfig {
/// If a weight is None, the default value will be used.
pub fn new(
overlap_score_weight: Option<f64>,
gpu_cache_usage_weight: Option<f64>,
waiting_requests_weight: Option<f64>,
temperature: Option<f64>,
max_num_batched_tokens: Option<u32>,
) -> Self {
let default = Self::default();
Self {
overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
gpu_cache_usage_weight: gpu_cache_usage_weight
.unwrap_or(default.gpu_cache_usage_weight),
waiting_requests_weight: waiting_requests_weight
.unwrap_or(default.waiting_requests_weight),
router_temperature: temperature.unwrap_or(default.router_temperature),
max_num_batched_tokens: max_num_batched_tokens
.unwrap_or(default.max_num_batched_tokens),
}
}
}
......@@ -103,7 +98,7 @@ impl KvRouterConfig {
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
pub struct KvRouter {
indexer: KvIndexer,
indexer: Option<KvIndexer>,
scheduler: KvScheduler,
block_size: u32,
}
......@@ -113,16 +108,19 @@ impl KvRouter {
component: Component,
block_size: u32,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
use_kv_events: bool,
) -> Result<Self> {
let cancellation_token = component
.drt()
.primary_lease()
.expect("Cannot KV route static workers")
.primary_token();
tracing::info!("KV Routing initialized");
let metrics_aggregator =
KvMetricsAggregator::new(component.clone(), cancellation_token.clone()).await;
let indexer = KvIndexer::new(cancellation_token.clone(), block_size);
EndpointCollector::new(component.clone(), cancellation_token.clone()).await;
let maybe_indexer =
use_kv_events.then(|| KvIndexer::new(cancellation_token.clone(), block_size));
let scheduler = KvScheduler::start(
component.namespace().clone(),
block_size,
......@@ -133,6 +131,7 @@ impl KvRouter {
// [gluo TODO] try subscribe_with_type::<RouterEvent>,
// error checking below will be different.
if let Some(ref indexer) = maybe_indexer {
let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
let kv_events_tx = indexer.event_sender();
......@@ -148,35 +147,31 @@ impl KvRouter {
}
};
if let Err(e) = kv_events_tx.send(event).await {
tracing::debug!("failed to send kv event to indexer; shutting down: {:?}", e);
tracing::debug!(
"failed to send kv event to indexer; shutting down: {:?}",
e
);
}
}
});
}
tracing::info!("KV Routing initialized");
Ok(Self {
indexer: maybe_indexer,
scheduler,
indexer,
block_size,
})
}
// [TODO] indexer needs to take 'lora_id' as parameter
pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<i64> {
// Extracting part of the code in KvRouter::generate() for only
// the decision making part, routing is done by the caller
let isl_tokens = token_ids.len();
let overlap_scores = self
.indexer
.find_matches_for_request(token_ids.as_slice())
.await?;
tracing::debug!("KV router overlap_scores: {:?}", overlap_scores);
let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
Ok(worker_id)
}
/// Give these tokens, find the worker with the best match in it's KV cache.
/// Returned overlap amount is in number of blocks.
async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<(i64, u32)> {
/// Now also takes context_id for request tracking
async fn find_best_match(
&self,
context_id: &str,
tokens: &[u32],
) -> anyhow::Result<(i64, u32)> {
let isl_tokens = tokens.len();
let block_size = self.block_size;
......@@ -187,13 +182,38 @@ impl KvRouter {
.into_iter()
.map(|block| LocalBlockHash(block.block_hash()))
.collect();
let overlap_scores = self.indexer.find_matches(local_block_hashes).await?;
let worker_id = self
let overlap_scores = match &self.indexer {
Some(indexer) => indexer.find_matches(local_block_hashes).await?,
None => Default::default(), // Returns empty/default instance
};
let best_worker_id = self
.scheduler
.schedule(overlap_scores.clone(), isl_tokens)
.schedule(
context_id.to_string(),
isl_tokens,
block_size,
tokens,
overlap_scores.clone(),
)
.await?;
let overlap_amount = overlap_scores.scores.get(&worker_id).copied().unwrap_or(0);
Ok((worker_id, overlap_amount))
let overlap_amount = overlap_scores
.scores
.get(&best_worker_id)
.copied()
.unwrap_or(0);
Ok((best_worker_id, overlap_amount))
}
/// Push a token to a specific request's sequence
pub async fn push(&self, request_id: &String, token: u32) {
self.scheduler.push(request_id, token).await
}
/// Free all blocks associated with a request
pub async fn free(&self, request_id: &String) {
self.scheduler.free(request_id).await
}
/// Get the block size this router was configured with
......@@ -202,6 +222,7 @@ impl KvRouter {
}
}
// NOTE: this would not be usable for now, should deprecate
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
async fn generate(
......@@ -209,7 +230,7 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
request: SingleIn<RouterRequest>,
) -> Result<ManyOut<Annotated<RouterResponse>>> {
let (request, ctx) = request.into_parts();
let (worker_id, _) = self.find_best_match(&request.tokens).await?;
let (worker_id, _) = self.find_best_match(ctx.id(), &request.tokens).await?;
let response = RouterResponse { worker_id };
let response = Annotated::from_data(response);
......@@ -243,13 +264,40 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
match self.inner.client.instance_source.as_ref() {
InstanceSource::Static => self.inner.r#static(request).await,
InstanceSource::Dynamic(_) => {
let (instance_id, overlap_amount) =
self.chooser.find_best_match(&request.token_ids).await?;
// Extract context ID for request tracking
let context_id = request.context().id().to_string();
let (instance_id, overlap_amount) = self
.chooser
.find_best_match(&context_id, &request.token_ids)
.await?;
// Update the request with the estimated prefix hit blocks
let (mut backend_input, context) = request.into_parts();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
let updated_request = context.map(|_| backend_input);
self.inner.direct(updated_request, instance_id).await
// Get the response stream from the worker
let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
// Wrap the stream to track tokens
let stream_context = response_stream.context();
let chooser = self.chooser.clone();
let request_id = context_id.clone();
let wrapped_stream = Box::pin(async_stream::stream! {
while let Some(item) = response_stream.next().await {
// Track tokens if they exist in the response
if let Some(ref output) = item.data {
for token_id in &output.token_ids {
chooser.push(&request_id, *token_id).await;
}
}
yield item;
}
chooser.free(&request_id).await;
});
Ok(ResponseStream::new(wrapped_stream, stream_context))
}
}
}
......
......@@ -15,7 +15,7 @@
use std::sync::Once;
pub use crate::kv_router::protocols::ForwardPassMetrics;
pub use crate::kv_router::protocols::{ForwardPassMetrics, LoadMetrics, PredictiveLoadMetrics};
use crate::kv_router::KV_METRICS_ENDPOINT;
use crate::kv_router::scheduler::Endpoint;
......@@ -28,6 +28,37 @@ use tokio_util::sync::CancellationToken;
static METRICS_WAITING_MESSAGE: Once = Once::new();
static METRICS_FOUND_MESSAGE: Once = Once::new();
pub struct EndpointCollector {
pub service_name: String,
pub endpoints_rx: watch::Receiver<ProcessedEndpoints>,
}
impl EndpointCollector {
pub async fn new(component: Component, cancellation_token: CancellationToken) -> Self {
let (watch_tx, watch_rx) = watch::channel(ProcessedEndpoints::default());
tokio::spawn(collect_endpoints_task(
component.clone(),
watch_tx,
cancellation_token.clone(),
"generate".to_string(),
));
Self {
service_name: component.service_name(),
endpoints_rx: watch_rx,
}
}
pub fn get_endpoints(&self) -> ProcessedEndpoints {
self.endpoints_rx.borrow().clone()
}
pub fn endpoints_watcher(&self) -> watch::Receiver<ProcessedEndpoints> {
self.endpoints_rx.clone()
}
}
pub struct KvMetricsAggregator {
pub service_name: String,
pub endpoints_rx: watch::Receiver<ProcessedEndpoints>,
......@@ -41,6 +72,7 @@ impl KvMetricsAggregator {
component.clone(),
watch_tx,
cancellation_token.clone(),
KV_METRICS_ENDPOINT.to_string(),
));
Self {
......@@ -93,12 +125,16 @@ pub async fn collect_endpoints_task(
component: Component,
watch_tx: watch::Sender<ProcessedEndpoints>,
cancel: CancellationToken,
subject: String,
) {
let backoff_delay = Duration::from_millis(100);
let scrape_timeout = Duration::from_millis(300);
let endpoint = component.endpoint(KV_METRICS_ENDPOINT);
let endpoint = component.endpoint(&subject);
let service_subject = endpoint.subject();
// Keep track of the last sent value to avoid unnecessary updates
let mut last_sent: Option<ProcessedEndpoints> = None;
loop {
tokio::select! {
_ = cancel.cancelled() => {
......@@ -115,31 +151,59 @@ pub async fn collect_endpoints_task(
continue;
}
};
let endpoints: Vec<Endpoint> = unfiltered_endpoints
let endpoints: Vec<Endpoint> = if subject == KV_METRICS_ENDPOINT {
// Original filtering behavior
unfiltered_endpoints
.into_iter()
.filter(|s| s.data.is_some())
.filter_map(|s|
match s.data.unwrap().decode::<ForwardPassMetrics>() {
Ok(data) => Some(Endpoint {
.filter_map(|s| {
s.data?
.decode::<ForwardPassMetrics>()
.map(|data| Endpoint {
name: s.name,
subject: s.subject,
data,
}),
Err(e) => {
tracing::debug!("skip endpoint data that can't be parsed as ForwardPassMetrics: {:?}", e);
None
}
}
)
.collect();
data: LoadMetrics::EngineLoadMetrics(data),
})
.inspect_err(|e| {
tracing::warn!("skip endpoint data that can't be parsed as ForwardPassMetrics: {:?}", e);
})
.ok()
})
.collect()
} else {
// No filtering - just use default LoadMetrics
unfiltered_endpoints
.into_iter()
.map(|s| Endpoint {
name: s.name,
subject: s.subject,
data: LoadMetrics::default(),
})
.collect()
};
tracing::trace!("Found {} endpoints for service: {service_subject}", endpoints.len());
let processed = ProcessedEndpoints::new(endpoints);
if watch_tx.send(processed).is_err() {
tracing::trace!("failed to send processed endpoints; shutting down");
// Only send if different from last sent value
// This is necessary because the watch channel does not track changes
// https://docs.rs/tokio/latest/tokio/sync/watch/struct.Receiver.html#method.has_changed
let should_send = match &last_sent {
Some(last) => last != &processed,
None => true,
};
if should_send {
tracing::trace!("Endpoints changed, sending update for service: {service_subject}");
if watch_tx.send(processed.clone()).is_err() {
tracing::error!("failed to send processed endpoints; shutting down");
break;
}
last_sent = Some(processed);
} else {
tracing::trace!("Endpoints unchanged, skipping update for service: {service_subject}");
}
}
}
}
......
......@@ -39,14 +39,14 @@ pub struct WorkerSelectionResult {
pub overlap_blocks: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct ForwardPassMetrics {
pub worker_stats: WorkerStats,
pub kv_stats: KvStats,
pub spec_decode_stats: Option<SpecDecodeStats>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct WorkerStats {
// https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models
pub data_parallel_rank: Option<u32>,
......@@ -55,7 +55,7 @@ pub struct WorkerStats {
pub num_requests_waiting: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct KvStats {
pub kv_active_blocks: u64,
pub kv_total_blocks: u64,
......@@ -65,7 +65,34 @@ pub struct KvStats {
pub gpu_prefix_cache_hit_rate: f32,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct PredictiveLoadMetrics {
pub kv_active_blocks: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum LoadMetrics {
EngineLoadMetrics(ForwardPassMetrics),
PredictiveLoadMetrics(PredictiveLoadMetrics),
}
impl LoadMetrics {
pub fn kv_active_blocks(&self) -> u64 {
match self {
LoadMetrics::EngineLoadMetrics(metrics) => metrics.kv_stats.kv_active_blocks,
LoadMetrics::PredictiveLoadMetrics(metrics) => metrics.kv_active_blocks,
}
}
}
impl Default for LoadMetrics {
fn default() -> Self {
LoadMetrics::PredictiveLoadMetrics(PredictiveLoadMetrics::default())
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct SpecDecodeStats {
pub num_spec_tokens: Option<u32>,
pub num_drafts: Option<u32>,
......
......@@ -17,16 +17,21 @@ use dynamo_runtime::component::Namespace;
use dynamo_runtime::traits::events::EventPublisher;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use super::protocols::WorkerSelectionResult;
use super::WorkerSelector;
use crate::kv_router::indexer::OverlapScores;
use crate::kv_router::protocols::ForwardPassMetrics;
use crate::kv_router::indexer::WorkerId;
use crate::kv_router::protocols::LoadMetrics;
use crate::kv_router::scoring::ProcessedEndpoints;
use crate::kv_router::sequence::ActiveSequencesMultiWorker;
use crate::kv_router::KvRouterConfig;
use crate::kv_router::KV_HIT_RATE_SUBJECT;
use crate::tokens::TokenBlockSequence;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVHitRateEvent {
......@@ -49,11 +54,11 @@ pub enum KvSchedulerError {
/// [gluo FIXME] exactly the same as EndpointInfo except that 'data'
/// is cleaned (not optional)
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Endpoint {
pub name: String,
pub subject: String,
pub data: ForwardPassMetrics,
pub data: LoadMetrics,
}
impl Endpoint {
......@@ -71,22 +76,30 @@ impl Endpoint {
}
}
#[derive(Debug)]
pub struct SchedulingResponse {
pub best_worker_id: i64,
pub endpoints_changed: Option<Vec<i64>>,
}
pub struct SchedulingRequest {
pub isl_tokens: usize,
pub overlap: OverlapScores,
resp_tx: tokio::sync::oneshot::Sender<i64>,
pub potential_blocks: HashMap<i64, usize>,
resp_tx: tokio::sync::oneshot::Sender<SchedulingResponse>,
}
impl SchedulingRequest {
pub fn respond(self, worker_id: i64) {
if self.resp_tx.send(worker_id).is_err() {
tracing::trace!("failed to send response to requestor");
pub fn respond(self, response: SchedulingResponse) {
if self.resp_tx.send(response).is_err() {
tracing::error!("failed to send response to requestor");
}
}
}
pub struct KvScheduler {
request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
sequences: Arc<Mutex<ActiveSequencesMultiWorker>>,
}
impl KvScheduler {
......@@ -110,12 +123,18 @@ impl KvScheduler {
}
});
let sequences = Arc::new(Mutex::new(ActiveSequencesMultiWorker::new(
block_size as usize,
endpoints.worker_ids(),
)));
// Channel to accept new scheduling requests
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
// Background task to handle scheduling requests
tokio::spawn(async move {
let mut request: SchedulingRequest;
let mut request_rx = request_rx;
let mut pending_endpoint_update: Option<Vec<i64>> = None;
tracing::trace!("scheduler background task started");
'outer: loop {
......@@ -137,30 +156,33 @@ impl KvScheduler {
_ = endpoints_rx.changed() => {
endpoints = endpoints_rx.borrow_and_update().clone();
pending_endpoint_update = Some(endpoints.worker_ids());
continue 'outer;
}
};
loop {
match selector.select_worker(&endpoints, &request, block_size) {
Ok(selection) => {
let worker_id = process_worker_selection(
endpoints.borrow_mut(),
selection,
&event_tx,
);
request.respond(worker_id);
if let Err(e) = event_tx.send(KVHitRateEvent {
worker_id: selection.worker_id,
isl_blocks: selection.required_blocks as usize,
overlap_blocks: selection.overlap_blocks,
}) {
tracing::warn!("Failed to send KV hit rate event: {:?}", e);
}
let response = SchedulingResponse {
best_worker_id: selection.worker_id,
endpoints_changed: pending_endpoint_update.take(),
};
request.respond(response);
continue 'outer;
}
Err(KvSchedulerError::AllWorkersBusy) => {
tracing::trace!("all workers busy; waiting for more capacity");
match endpoints_rx.changed().await {
Ok(_) => {}
Err(e) => {
tracing::error!("error waiting for endpoints change: {:?}", e);
break 'outer;
}
};
endpoints = endpoints_rx.borrow_and_update().clone();
tokio::time::sleep(Duration::from_millis(5)).await;
continue;
}
Err(e) => {
tracing::error!("error scheduling request: {:?}", e);
......@@ -173,59 +195,81 @@ impl KvScheduler {
tracing::trace!("background endpoint subscriber shutting down");
});
Ok(KvScheduler { request_tx })
Ok(KvScheduler {
request_tx,
sequences,
})
}
pub async fn schedule(
&self,
overlap: OverlapScores,
request_id: String,
isl_tokens: usize,
block_size: u32,
tokens: &[u32],
overlap: OverlapScores,
) -> Result<i64, KvSchedulerError> {
let mut sequences = self.sequences.lock().await;
let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
let potential_blocks = sequences.potential_blocks(token_sequence);
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest {
isl_tokens,
overlap,
potential_blocks,
resp_tx,
};
self.request_tx
.send(request)
.await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?;
let res = resp_rx
let response = resp_rx
.await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?;
Ok(res)
if let Some(new_worker_ids) = response.endpoints_changed {
sequences.update_workers(new_worker_ids);
}
}
// This becomes the driver function that handles the selection result
pub fn process_worker_selection(
workers: &mut ProcessedEndpoints,
selection: WorkerSelectionResult,
event_tx: &tokio::sync::mpsc::UnboundedSender<KVHitRateEvent>,
) -> i64 {
let worker = workers
.endpoints
.get_mut(&selection.worker_id)
.expect("worker not found");
// Update worker state predictively
// Will be overwritten on next polling of metrics
worker.data.kv_stats.kv_active_blocks += selection
.required_blocks
.saturating_sub(selection.overlap_blocks as u64);
// Emit event
if let Err(e) = event_tx.send(KVHitRateEvent {
worker_id: selection.worker_id,
isl_blocks: selection.required_blocks as usize,
overlap_blocks: selection.overlap_blocks,
}) {
tracing::warn!("Failed to send KV hit rate event: {:?}", e);
let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
sequences.add_request(request_id, token_sequence, response.best_worker_id);
Ok(response.best_worker_id)
}
selection.worker_id
/// Find the potential blocks for each worker if the sequence were routed there
pub async fn potential_blocks(
&self,
token_sequence: TokenBlockSequence,
) -> HashMap<i64, usize> {
let sequences = self.sequences.lock().await;
sequences.potential_blocks(token_sequence)
}
/// Add a new request with its initial tokens to a specific worker
pub async fn add_request(
&self,
request_id: String,
token_sequence: TokenBlockSequence,
worker_id: WorkerId,
) {
let mut sequences = self.sequences.lock().await;
sequences.add_request(request_id, token_sequence, worker_id)
}
/// Push a token to a specific request's sequence
pub async fn push(&self, request_id: &String, token: u32) {
let mut sequences = self.sequences.lock().await;
sequences.push(request_id, token)
}
/// Free all blocks associated with a request
pub async fn free(&self, request_id: &String) {
let mut sequences = self.sequences.lock().await;
sequences.free(request_id)
}
}
// Helper function for softmax sampling
......@@ -234,6 +278,24 @@ fn softmax_sample(logits: &HashMap<i64, f64>, temperature: f64) -> i64 {
panic!("Empty logits for softmax sampling");
}
// Guard: if temperature is 0, return the key with the smallest logit value
if temperature == 0.0 {
// Find the minimum logit value
let min_logit = logits.values().fold(f64::INFINITY, |a, &b| a.min(b));
// Collect all keys with the minimum logit value (to handle ties)
let min_keys: Vec<_> = logits
.iter()
.filter(|(_, &v)| v == min_logit)
.map(|(k, _)| *k)
.collect();
// Randomly select from the minimum keys (handles single key case naturally)
let mut rng = rand::rng();
let index = rng.random_range(0..min_keys.len());
return min_keys[index];
}
let keys: Vec<_> = logits.keys().copied().collect();
let values: Vec<_> = logits.values().copied().collect();
......@@ -309,57 +371,42 @@ impl WorkerSelector for DefaultWorkerSelector {
}
let request_blocks = request.isl_tokens.div_ceil(block_size as usize);
let potential_active_blocks = &request.potential_blocks;
let mut worker_logits = HashMap::new();
let mut max_logit = f64::NEG_INFINITY;
// Calculate logits for each worker
for (worker_id, ep) in workers.endpoints.iter() {
let worker_id = *worker_id;
// Get overlap blocks for this worker
let overlap_blocks =
request.overlap.scores.get(&worker_id).copied().unwrap_or(0) as f64;
let new_blocks = request_blocks as f64 - overlap_blocks;
let kv_total_blocks = ep.data.kv_stats.kv_total_blocks as f64;
assert!(kv_total_blocks > 0.0);
for (worker_id, _) in workers.endpoints.iter() {
let cached_blocks = request.overlap.scores.get(worker_id).copied().unwrap_or(0) as f64;
let prefill_blocks = request_blocks as f64 - cached_blocks;
let normalized_new_blocks = new_blocks / kv_total_blocks;
let gpu_cache_usage = ep.data.kv_stats.gpu_cache_usage_perc as f64;
let num_requests_waiting = ep.data.worker_stats.num_requests_waiting as f64;
// this is the number of blocks each worker would have if the request were scheduled there
let potential_blocks = *potential_active_blocks.get(worker_id).unwrap_or_else(||
{tracing::warn!("assuming 0 decoding blocks for {worker_id}, as the load metrics endpoint does not exist yet");
&0
}) as f64;
// Calculate logit (lower is better)
let logit = self.kv_router_config.overlap_score_weight * normalized_new_blocks
+ self.kv_router_config.gpu_cache_usage_weight * gpu_cache_usage
+ self.kv_router_config.waiting_requests_weight * num_requests_waiting;
let logit =
self.kv_router_config.overlap_score_weight * prefill_blocks + potential_blocks;
max_logit = max_logit.max(logit);
worker_logits.insert(worker_id, logit);
worker_logits.insert(*worker_id, logit);
tracing::info!(
"Formula for {worker_id}: {logit:.3} = {:.1} * {normalized_new_blocks:.3} + {:.1} * {gpu_cache_usage:.3} + {:.1} * {num_requests_waiting:.3}",
"Formula for {worker_id}: {logit:.3} = {:.1} * {prefill_blocks:.3} + {potential_blocks:.3}",
self.kv_router_config.overlap_score_weight,
self.kv_router_config.gpu_cache_usage_weight,
self.kv_router_config.waiting_requests_weight,
);
}
// Return early if no valid workers found
if worker_logits.is_empty() || worker_logits.values().all(|&v| v == 0.0) {
tracing::warn!("All worker logits are zero. Fallback to random routing.");
// Pick random worker
let mut rng = rand::rng();
let worker_ids: Vec<_> = workers.endpoints.keys().copied().collect();
let worker_id = worker_ids[rng.random_range(0..worker_ids.len())];
let overlap_blocks =
request.overlap.scores.get(&worker_id).copied().unwrap_or(0) as usize;
return Ok(WorkerSelectionResult {
worker_id,
required_blocks: request_blocks as u64,
overlap_blocks,
});
// Normalize by dividing by max value
for logit in worker_logits.values_mut() {
*logit /= max_logit;
}
// Use softmax sampling to select worker
let temperature = 1.0; // You can make this configurable if needed
let temperature = self.kv_router_config.router_temperature;
let best_worker_id = softmax_sample(&worker_logits, temperature);
let overlap_blocks = request
......@@ -371,7 +418,7 @@ impl WorkerSelector for DefaultWorkerSelector {
let best_logit = worker_logits[&best_worker_id];
tracing::info!(
"Selected worker: {}, logit: {:.3}",
"Selected worker: {}, normalized logit: {:.3}",
best_worker_id,
best_logit
);
......@@ -387,7 +434,6 @@ impl WorkerSelector for DefaultWorkerSelector {
#[cfg(test)]
mod tests {
use super::*;
use crate::kv_router::protocols::{KvStats, WorkerStats};
#[test]
fn test_softmax_sample_single_key() {
......@@ -416,191 +462,31 @@ mod tests {
assert_eq!(softmax_sample(&logits, 1.0), worker_id);
}
// Helper to create a worker endpoint
fn create_endpoint(
worker_id: i64,
gpu_cache_usage_perc: f32,
num_requests_waiting: u64,
) -> Endpoint {
Endpoint {
name: format!("worker-{}", worker_id),
subject: format!("worker-subject-{:x}", worker_id),
data: ForwardPassMetrics {
kv_stats: KvStats {
gpu_cache_usage_perc,
..Default::default()
},
worker_stats: WorkerStats {
num_requests_waiting,
..Default::default()
},
// Other fields can be default initialized for this test
..Default::default()
},
}
}
// Helper to create ProcessedEndpoints
struct WorkerInfo {
id: i64,
usage: f32,
waiting: u64,
}
fn create_workers(workers: Vec<WorkerInfo>) -> ProcessedEndpoints {
let mut endpoints = HashMap::new();
for worker in workers {
endpoints.insert(
worker.id,
create_endpoint(worker.id, worker.usage, worker.waiting),
#[test]
fn test_softmax_sample_zero_temperature() {
// Test that with temperature 0, softmax_sample returns the key with smallest logit
let mut logits = HashMap::new();
logits.insert(1, 5.0);
logits.insert(2, 3.0); // This has the smallest logit
logits.insert(3, 7.0);
logits.insert(4, 3.5);
// With temperature 0, should always return worker 2 (smallest logit)
for _ in 0..10 {
let result = softmax_sample(&logits, 0.0);
assert_eq!(
result, 2,
"Should return worker with smallest logit when temperature is 0"
);
}
ProcessedEndpoints {
endpoints,
load_avg: 0.0,
load_std: 0.0,
}
}
// Helper to create a scheduling request
struct WorkerOverlap {
worker_id: i64,
overlap_blocks: u32,
}
fn create_request(overlaps: Vec<WorkerOverlap>, isl_tokens: usize) -> SchedulingRequest {
SchedulingRequest {
isl_tokens,
overlap: OverlapScores {
scores: overlaps
.into_iter()
.map(|wo| (wo.worker_id, wo.overlap_blocks))
.collect(),
frequencies: vec![],
},
resp_tx: tokio::sync::oneshot::channel().0,
}
}
// Test with negative values
logits.clear();
logits.insert(10, -1.0);
logits.insert(20, -5.0); // This has the smallest logit
logits.insert(30, 0.0);
#[test]
fn test_no_endpoints() {
let workers = create_workers(vec![]);
let request = create_request(vec![], 100);
let selector = DefaultWorkerSelector::new(None);
let block_size = 20;
match selector.select_worker(&workers, &request, block_size) {
Err(KvSchedulerError::NoEndpoints) => {} // Expected
_ => panic!("Should return NoEndpoints error"),
}
}
// #[test]
// fn test_select_worker_basic() {
// // Setup workers
// let workers = create_workers(vec![
// WorkerInfo {
// id: 1,
// usage: 0.50,
// waiting: 1,
// },
// WorkerInfo {
// id: 2,
// usage: 0.80,
// waiting: 0,
// },
// ]);
// // Setup request: 100 tokens, block_size=20 (5 blocks)
// let request = create_request(
// vec![
// WorkerOverlap {
// worker_id: 1,
// overlap_blocks: 3,
// },
// WorkerOverlap {
// worker_id: 2,
// overlap_blocks: 4,
// },
// ],
// 100,
// );
// let selector = DefaultWorkerSelector::new(None);
// let block_size = 20;
// // Execute selection
// let result = selector
// .select_worker(&workers, &request, block_size)
// .expect("Should select a worker");
// // Worker 2 should win because:
// // Worker1: 2.0 * 0.600 - 1.0 * 0.500 - 1.0 * 1.000 = -0.3
// // Worker2: 2.0 * 0.800 - 1.0 * 0.800 - 1.0 * 0.000 = 0.8
// assert_eq!(result.worker_id, 2);
// assert_eq!(result.required_blocks, 5); // 100 tokens / 20 block_size
// assert_eq!(result.overlap_blocks, 4);
// }
// #[test]
// fn test_no_overlap_scores() {
// // Workers exist but request has no overlap scores
// let workers = create_workers(vec![WorkerInfo {
// id: 1,
// usage: 0.50,
// waiting: 1,
// }]);
// let request = create_request(vec![], 100); // No overlaps
// let selector = DefaultWorkerSelector::new(None);
// let block_size = 20;
// let result = selector
// .select_worker(&workers, &request, block_size)
// .expect("Should fallback to selecting worker");
// // Worker1 should be selected with 0 overlap
// assert_eq!(result.worker_id, 1);
// assert_eq!(result.overlap_blocks, 0);
// }
// #[test]
// fn test_custom_weights() {
// // Setup workers
// let workers = create_workers(vec![
// WorkerInfo {
// id: 1,
// usage: 0.50,
// waiting: 1,
// },
// WorkerInfo {
// id: 2,
// usage: 0.80,
// waiting: 0,
// },
// ]);
// // Custom config with high priority on GPU usage
// let config = KvRouterConfig {
// gpu_cache_usage_weight: 10.0, // Very high weight
// overlap_score_weight: 2.0, // just current defaults
// waiting_requests_weight: 1.0,
// };
// let selector = DefaultWorkerSelector::new(Some(config));
// let request = create_request(
// vec![
// WorkerOverlap {
// worker_id: 1,
// overlap_blocks: 3,
// },
// WorkerOverlap {
// worker_id: 2,
// overlap_blocks: 4,
// },
// ],
// 100,
// );
// let block_size = 20;
// let result = selector
// .select_worker(&workers, &request, block_size)
// .expect("Should select worker");
// assert_eq!(result.worker_id, 1);
// }
let result = softmax_sample(&logits, 0.0);
assert_eq!(result, 20, "Should handle negative logits correctly");
}
}
......@@ -20,7 +20,7 @@ use std::collections::HashMap;
use crate::kv_router::scheduler::Endpoint;
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)]
pub struct ProcessedEndpoints {
pub endpoints: HashMap<i64, Endpoint>,
pub load_avg: f64,
......@@ -32,7 +32,7 @@ impl ProcessedEndpoints {
// compute some basic statistics
let load_values: Vec<f64> = endpoints
.iter()
.map(|x| x.data.kv_stats.kv_active_blocks as f64)
.map(|endpoint| endpoint.data.kv_active_blocks() as f64)
.collect();
let load_avg = load_values.iter().copied().sum::<f64>() / load_values.len() as f64;
let variance = load_values
......@@ -50,4 +50,15 @@ impl ProcessedEndpoints {
load_std,
}
}
pub fn worker_ids(&self) -> Vec<i64> {
self.endpoints.keys().copied().collect()
}
pub fn active_blocks(&self) -> HashMap<i64, usize> {
self.endpoints
.iter()
.map(|(&worker_id, endpoint)| (worker_id, endpoint.data.kv_active_blocks() as usize))
.collect()
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! KV Cache Sequence Management for LLM Inference
//!
//! This module provides efficient management of token sequences and their associated KV cache blocks
//! for distributed LLM inference. It implements a shared block system where multiple requests can
//! reuse the same KV cache blocks for common token prefixes, significantly reducing memory usage.
//!
//! # Key Components
//!
//! - [`ActiveSequences`]: Single-threaded sequence manager that tracks active requests and their
//! token sequences, managing shared KV cache blocks efficiently.
//!
//! - [`ActiveSequencesMultiWorker`]: Multi-threaded extension that distributes sequence management
//! across multiple worker threads, enabling parallel processing of requests while maintaining
//! consistency.
//!
//! # Architecture
//!
//! The system uses a block-based approach where token sequences are divided into fixed-size blocks.
//! Each block is identified by a hash of its contents, allowing for deduplication when multiple
//! requests share common prefixes (e.g., system prompts, few-shot examples).
use crate::kv_router::indexer::WorkerId;
use crate::tokens::blocks::UniqueBlock;
use crate::tokens::TokenBlockSequence;
use derive_getters::Getters;
use std::collections::{HashMap, HashSet};
use std::sync::{mpsc, Arc};
use std::thread;
use std::time::Duration;
use uuid;
// TODO: use the common request_id if it exists in the repo
pub type RequestId = String;
/// Create unique blocks from a TokenBlockSequence
fn create_unique_blocks_from_sequence(
tokens: &TokenBlockSequence,
uuid: Option<uuid::Uuid>,
block_size: usize,
) -> Vec<UniqueBlock> {
let mut unique_blocks: Vec<UniqueBlock> = tokens
.blocks()
.iter()
.map(|block| UniqueBlock::FullBlock(block.sequence_hash()))
.collect();
// Only push the partial block if tokens count isn't a multiple of block_size
if tokens.total_tokens() % block_size != 0 {
unique_blocks.push(match uuid {
Some(uuid) => UniqueBlock::PartialBlock(uuid),
None => UniqueBlock::default(),
});
}
unique_blocks
}
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug, Getters)]
pub struct ActiveSequences {
active_seqs: HashMap<RequestId, TokenBlockSequence>,
partial_blocks: HashMap<RequestId, UniqueBlock>,
unique_blocks: HashMap<UniqueBlock, HashSet<RequestId>>,
#[getter(copy)]
block_size: usize,
#[getter(copy)]
active_blocks: usize,
}
impl ActiveSequences {
/// Create a new SharedSequenceManager instance
pub fn new(block_size: usize) -> Self {
// TODO: make this not a hard req
assert!(block_size > 1, "block_size must be greater than 1");
Self {
active_seqs: HashMap::new(),
partial_blocks: HashMap::new(),
unique_blocks: HashMap::new(),
block_size,
active_blocks: 0,
}
}
fn add_block(&mut self, request_id: RequestId, block: &UniqueBlock) {
let is_new_block = !self.unique_blocks.contains_key(block);
self.unique_blocks
.entry(block.clone())
.or_default()
.insert(request_id.clone());
if is_new_block {
self.active_blocks += 1;
}
if matches!(block, UniqueBlock::PartialBlock(_)) {
self.partial_blocks.insert(request_id, block.clone());
};
}
fn remove_block(&mut self, request_id: &RequestId, block: &UniqueBlock) {
let Some(request_ids) = self.unique_blocks.get_mut(block) else {
panic!("Cannot remove a block that does not exist.")
};
// Remove the unique block if no more requests using it
request_ids.retain(|w| w != request_id);
if request_ids.is_empty() {
self.active_blocks -= 1;
self.unique_blocks.remove(block);
}
}
/// Add a new request with its initial tokens
pub fn add_request(
&mut self,
request_id: RequestId,
token_sequence: TokenBlockSequence,
) -> usize {
let blocks = create_unique_blocks_from_sequence(&token_sequence, None, self.block_size);
for block in &blocks {
self.add_block(request_id.clone(), block);
}
self.active_seqs.insert(request_id.clone(), token_sequence);
self.active_blocks
}
/// Match a request against existing blocks and return the number of new blocks that would be added
pub fn new_blocks(&self, token_sequence: &TokenBlockSequence) -> usize {
let blocks = create_unique_blocks_from_sequence(token_sequence, None, self.block_size);
blocks
.iter()
.filter(|block| !self.unique_blocks.contains_key(block))
.count()
}
/// Return the total number of blocks that would be used if the token sequence was added
/// This is the sum of new blocks that would be added plus the current active blocks
pub fn potential_blocks(&self, token_sequence: &TokenBlockSequence) -> usize {
self.new_blocks(token_sequence) + self.active_blocks
}
/// Free all blocks associated with a request
pub fn free(&mut self, request_id: &RequestId) -> usize {
let Some(token_seq) = self.active_seqs.get(request_id) else {
tracing::warn!("Trying to free free non-existent request {request_id}");
return 0;
};
let blocks = create_unique_blocks_from_sequence(token_seq, None, self.block_size);
for block in blocks {
if matches!(block, UniqueBlock::FullBlock(_)) {
self.remove_block(request_id, &block);
}
}
if let Some(partial_block) = self.partial_blocks.remove(request_id) {
self.remove_block(request_id, &partial_block);
}
self.active_seqs.remove(request_id).unwrap();
self.active_blocks
}
/// Push a token to a specific request's sequence
pub fn push(&mut self, request_id: &RequestId, token: u32) -> usize {
let token_seq = self
.active_seqs
.get_mut(request_id)
.expect("Request ID not found for token push");
token_seq.append(token).expect("Token push failed.");
// No need to update anything
if token_seq.total_tokens() % self.block_size != 1 {
return self.active_blocks;
}
let last_seq_hash = token_seq
.last_complete_block()
.map(|block| block.sequence_hash());
// Promote a partial block into a full block if not already
if let Some(partial_block) = self.partial_blocks.get(request_id).cloned() {
self.remove_block(request_id, &partial_block);
}
if let Some(full_block) = last_seq_hash {
self.add_block(request_id.clone(), &UniqueBlock::FullBlock(full_block));
}
self.add_block(request_id.clone(), &UniqueBlock::default());
self.active_blocks
}
}
#[derive(Debug)]
enum UpdateSequences {
AddRequest {
request_id: RequestId,
token_sequence: TokenBlockSequence,
},
Free {
request_id: RequestId,
},
Push {
request_id: RequestId,
token: u32,
},
NewBlocks {
token_sequence: Arc<TokenBlockSequence>,
resp_tx: mpsc::SyncSender<usize>,
},
PotentialBlocks {
token_sequence: Arc<TokenBlockSequence>,
resp_tx: mpsc::SyncSender<usize>,
},
ActiveBlocks {
resp_tx: mpsc::SyncSender<usize>,
},
Shutdown,
}
/// Multi-worker extension of ActiveSequences that distributes requests across multiple threads
pub struct ActiveSequencesMultiWorker {
senders: HashMap<WorkerId, mpsc::Sender<UpdateSequences>>,
request_to_worker: HashMap<RequestId, WorkerId>,
handles: HashMap<WorkerId, thread::JoinHandle<()>>,
block_size: usize,
}
impl ActiveSequencesMultiWorker {
pub fn new(block_size: usize, worker_ids: Vec<WorkerId>) -> Self {
assert!(block_size > 1, "block_size must be greater than 1");
let mut senders = HashMap::new();
let mut handles = HashMap::new();
for worker_id in worker_ids {
let (sender, handle) = Self::start_worker(block_size);
senders.insert(worker_id, sender);
handles.insert(worker_id, handle);
}
Self {
senders,
request_to_worker: HashMap::new(),
handles,
block_size,
}
}
/// Helper method to start a worker thread
fn start_worker(block_size: usize) -> (mpsc::Sender<UpdateSequences>, thread::JoinHandle<()>) {
let (request_tx, request_rx) = mpsc::channel::<UpdateSequences>();
let handle = thread::spawn(move || {
let mut active_sequences = ActiveSequences::new(block_size);
while let Ok(command) = request_rx.recv() {
match command {
UpdateSequences::AddRequest {
request_id,
token_sequence,
} => {
active_sequences.add_request(request_id, token_sequence);
}
UpdateSequences::Free { request_id } => {
active_sequences.free(&request_id);
}
UpdateSequences::Push { request_id, token } => {
active_sequences.push(&request_id, token);
}
UpdateSequences::NewBlocks {
token_sequence,
resp_tx,
} => {
let new_blocks = active_sequences.new_blocks(&token_sequence);
let _ = resp_tx.send(new_blocks);
}
UpdateSequences::PotentialBlocks {
token_sequence,
resp_tx,
} => {
let potential_blocks = active_sequences.potential_blocks(&token_sequence);
let _ = resp_tx.send(potential_blocks);
}
UpdateSequences::ActiveBlocks { resp_tx } => {
let active_blocks = active_sequences.active_blocks();
let _ = resp_tx.send(active_blocks);
}
UpdateSequences::Shutdown => {
break;
}
}
}
});
(request_tx, handle)
}
/// Update the set of workers, adding and removing as needed
pub fn update_workers(&mut self, new_worker_ids: Vec<WorkerId>) -> HashMap<WorkerId, usize> {
let current_workers: HashSet<WorkerId> = self.senders.keys().copied().collect();
let new_workers: HashSet<WorkerId> = new_worker_ids.into_iter().collect();
let workers_to_remove: Vec<WorkerId> =
current_workers.difference(&new_workers).copied().collect();
let workers_to_add: Vec<WorkerId> =
new_workers.difference(&current_workers).copied().collect();
// Remove workers
for worker_id in &workers_to_remove {
tracing::warn!("Removing worker {}", worker_id);
// Send shutdown command to the worker
if let Some(sender) = self.senders.remove(worker_id) {
let _ = sender.send(UpdateSequences::Shutdown);
}
if let Some(handle) = self.handles.remove(worker_id) {
let _ = handle.join();
}
}
// Add new workers
for worker_id in &workers_to_add {
tracing::warn!("Adding worker {}", worker_id);
let (sender, handle) = Self::start_worker(self.block_size);
self.senders.insert(*worker_id, sender);
self.handles.insert(*worker_id, handle);
}
// Return active blocks for all workers
self.active_blocks()
}
pub fn add_request(
&mut self,
request_id: RequestId,
token_sequence: TokenBlockSequence,
worker_id: WorkerId,
) {
if !self.senders.contains_key(&worker_id) {
panic!("Worker ID {worker_id} not found");
}
self.request_to_worker.insert(request_id.clone(), worker_id);
self.senders[&worker_id]
.send(UpdateSequences::AddRequest {
request_id,
token_sequence,
})
.expect("Failed to send add_request command to worker");
}
pub fn free(&mut self, request_id: &RequestId) {
let worker_id = self
.request_to_worker
.get(request_id)
.copied()
.expect("Request ID not found in request_to_worker mapping");
self.senders[&worker_id]
.send(UpdateSequences::Free {
request_id: request_id.clone(),
})
.expect("Failed to send free command to worker");
self.request_to_worker.remove(request_id);
}
pub fn push(&mut self, request_id: &RequestId, token: u32) {
let worker_id = self
.request_to_worker
.get(request_id)
.copied()
.expect("Request ID not found in request_to_worker mapping");
self.senders[&worker_id]
.send(UpdateSequences::Push {
request_id: request_id.clone(),
token,
})
.expect("Failed to send push command to worker");
}
/// Get the number of workers
pub fn num_workers(&self) -> usize {
self.senders.len()
}
/// Generic method to query all workers with a given command
fn query_workers(
&self,
token_sequence: Option<TokenBlockSequence>,
command_fn: impl Fn(Option<Arc<TokenBlockSequence>>, mpsc::SyncSender<usize>) -> UpdateSequences,
) -> HashMap<WorkerId, usize> {
let mut results = HashMap::new();
let token_sequence_shared = token_sequence.map(Arc::new);
let mut receivers = Vec::new();
// Send queries to all workers in parallel
for (worker_id, sender) in &self.senders {
let (resp_tx, resp_rx) = mpsc::sync_channel(0);
receivers.push((worker_id, resp_rx));
sender
.send(command_fn(token_sequence_shared.clone(), resp_tx))
.expect("Failed to send command to worker");
}
// Collect results from all workers
for (worker_id, receiver) in receivers {
let result = receiver
.recv_timeout(Duration::from_secs(1))
.expect("Failed to receive response from worker");
results.insert(*worker_id, result);
}
results
}
/// Query all workers for the number of new blocks that would be added by a token sequence
pub fn new_blocks(&self, token_sequence: TokenBlockSequence) -> HashMap<WorkerId, usize> {
self.query_workers(Some(token_sequence), |ts, resp_tx| match ts {
Some(ts) => UpdateSequences::NewBlocks {
token_sequence: ts,
resp_tx,
},
None => unreachable!("token_sequence should always be Some for new_blocks"),
})
}
/// Query all workers for the total number of blocks (new + active) that would be used by a token sequence
pub fn potential_blocks(&self, token_sequence: TokenBlockSequence) -> HashMap<WorkerId, usize> {
self.query_workers(Some(token_sequence), |ts, resp_tx| match ts {
Some(ts) => UpdateSequences::PotentialBlocks {
token_sequence: ts,
resp_tx,
},
None => unreachable!("token_sequence should always be Some for potential_blocks"),
})
}
/// Query all workers for their current number of active blocks
pub fn active_blocks(&self) -> HashMap<WorkerId, usize> {
self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveBlocks { resp_tx })
}
}
impl Drop for ActiveSequencesMultiWorker {
fn drop(&mut self) {
// Send shutdown command to all workers
for sender in self.senders.values() {
let _ = sender.send(UpdateSequences::Shutdown);
}
// Wait for all threads to finish
for (_, handle) in self.handles.drain() {
let _ = handle.join();
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tokens::Tokens;
#[test]
fn test_shared_sequence_manager_operations() {
let block_size = 4;
let mut manager = ActiveSequences::new(block_size);
let to_sequence =
|tokens: Vec<u32>| Tokens::from(tokens).into_sequence(block_size as u32, None);
// Step 1: Add request 0 with tokens [0, 1, 2], then push 3 and 4
manager.add_request("0".to_string(), to_sequence(vec![0, 1, 2]));
manager.push(&"0".to_string(), 3);
manager.push(&"0".to_string(), 4);
assert_eq!(manager.active_blocks(), 2);
assert_eq!(manager.partial_blocks.len(), 1);
// Step 2: Add request 1 with tokens [0, 1, 2, 3, 4, 5, 6]
manager.add_request("1".to_string(), to_sequence(vec![0, 1, 2, 3, 4, 5, 6]));
assert_eq!(manager.active_blocks(), 3);
// Check that only one key is FullBlock with both requests sharing it
let mut full_block_count = 0;
let mut shared_block_requests = None;
for (block, requests) in &manager.unique_blocks {
if let UniqueBlock::FullBlock(_) = block {
full_block_count += 1;
if requests.len() == 2 {
shared_block_requests = Some(requests.clone());
}
}
}
assert_eq!(full_block_count, 1);
assert!(shared_block_requests.is_some());
let shared_requests = shared_block_requests.unwrap();
assert!(shared_requests.contains("0"));
assert!(shared_requests.contains("1"));
let new_blocks = manager.new_blocks(&to_sequence(vec![0, 1, 2, 3, 4, 5]));
assert_eq!(new_blocks, 1);
// Step 3: Free request 1
manager.free(&"1".to_string());
assert_eq!(manager.active_blocks(), 2);
// Step 4: Free request 0
manager.free(&"0".to_string());
assert_eq!(manager.active_blocks(), 0);
assert_eq!(manager.unique_blocks.len(), 0);
assert_eq!(manager.partial_blocks.len(), 0);
assert_eq!(manager.active_seqs.len(), 0);
}
#[test]
fn test_active_sequences_multi_worker() {
let block_size = 4;
let worker_ids = vec![0, 1, 2];
let mut manager = ActiveSequencesMultiWorker::new(block_size, worker_ids);
let to_sequence =
|tokens: Vec<u32>| Tokens::from(tokens).into_sequence(block_size as u32, None);
// Send request [0, 1, 2, 3] to worker 0
manager.add_request("req0".to_string(), to_sequence(vec![0, 1, 2, 3]), 0);
// Send request [0, 1, 2] to worker 1, then push 3 and push 4
manager.add_request("req1".to_string(), to_sequence(vec![0, 1, 2]), 1);
manager.push(&"req1".to_string(), 3);
manager.push(&"req1".to_string(), 4);
// Send request [0, 1, 2] to worker 2
manager.add_request("req2".to_string(), to_sequence(vec![0, 1, 2]), 2);
// Check new_blocks on tokens [0, 1, 2, 3, 4]
let new_blocks_map = manager.new_blocks(to_sequence(vec![0, 1, 2, 3, 4]));
assert_eq!(new_blocks_map[&0], 1); // Worker 0 would have 1 new block
assert_eq!(new_blocks_map[&1], 1); // Worker 1 would have 1 new block
assert_eq!(new_blocks_map[&2], 2); // Worker 2 would have 2 new blocks
manager.update_workers(vec![0, 1]);
manager.update_workers(vec![0, 1, 3]);
let new_blocks_map = manager.new_blocks(to_sequence(vec![0, 1, 2, 3, 4]));
assert_eq!(new_blocks_map.len(), 3);
assert_eq!(new_blocks_map[&3], 2);
}
}
......@@ -46,8 +46,9 @@
//! implementation of the main block manager.
use crate::mocker::evictor::LRUEvictor;
use crate::mocker::protocols::{MoveBlock, MoveBlockResponse, PrefillCost, UniqueBlock};
use crate::mocker::protocols::{MoveBlock, MoveBlockResponse, PrefillCost};
use crate::mocker::sequence::ActiveSequence;
use crate::tokens::blocks::UniqueBlock;
use derive_getters::Getters;
use std::collections::{HashMap, HashSet};
use tokio::sync::mpsc;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment