Unverified Commit 84e71e27 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: predictive active blocks for routing without load metrics (#1731)


Signed-off-by: default avatarYan Ru Pei <yanrpei@gmail.com>
Co-authored-by: default avatarAlec <35311602+alec-flowers@users.noreply.github.com>
parent ffccc722
......@@ -83,7 +83,7 @@ use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::time::Duration as StdDuration;
use dynamo_llm::kv_router::protocols::ForwardPassMetrics;
use dynamo_llm::kv_router::protocols::{ForwardPassMetrics, LoadMetrics};
use dynamo_llm::kv_router::scheduler::Endpoint;
use dynamo_llm::kv_router::scoring::ProcessedEndpoints;
......@@ -449,7 +449,10 @@ impl PrometheusMetrics {
// Update per-worker metrics
for (worker_id, endpoint) in processed.endpoints.iter() {
let worker_id = worker_id.to_string();
let metrics = endpoint.data.clone();
let load_metrics = endpoint.data.clone();
let LoadMetrics::EngineLoadMetrics(metrics) = load_metrics else {
panic!("Can only update with ForwardPassMetrics");
};
self.set_worker_gauge(
&self.kv_blocks_active,
......@@ -602,7 +605,7 @@ pub fn postprocess_metrics(
e.id().ok().map(|id| Endpoint {
name: format!("worker-{id}"),
subject: e.subject.clone(),
data: m.clone(),
data: LoadMetrics::EngineLoadMetrics(m.clone()),
})
})
.collect();
......
......@@ -66,7 +66,7 @@ async fn app(runtime: Runtime) -> Result<()> {
let selector = Box::new(CustomWorkerSelector::default());
let router = KvRouter::new(component.clone(), args.block_size, Some(selector)).await?;
let router = KvRouter::new(component.clone(), args.block_size, Some(selector), true).await?;
let router = Ingress::for_engine(Arc::new(router))?;
component
......
......@@ -3392,14 +3392,8 @@ index cafd8150b..6a5e45b4e 100644
+ num_requests_waiting: int
+ gpu_cache_usage_perc: float
+ gpu_prefix_cache_hit_rate: float
+ spec_decode_draft_acceptance_rate: Optional[float] = None
+ spec_decode_system_efficiency: Optional[float] = None
+ spec_decode_draft_tokens: Optional[int] = None
+ spec_decode_emitted_tokens: Optional[int] = None
+ spec_decode_accepted_tokens: Optional[int] = None
+ spec_decode_num_spec_tokens: Optional[int] = None
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index f058b1329..2fdb5b8bf 100644
index f058b1329..fd5610a3c 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -1,4 +1,17 @@
......@@ -3460,16 +3454,25 @@ index f058b1329..2fdb5b8bf 100644
from vllm.engine.protocol import EngineClient
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
@@ -48,6 +66,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -48,6 +66,17 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import Device, deprecate_kwargs
+from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest, RemotePrefillRequestCallback
+from vllm.distributed.device_communicators.nixl import NixlMetadata
+
+# Import ForwardPassMetrics and related classes from dynamo
+try:
+ from dynamo.llm import ForwardPassMetrics, WorkerStats, KvStats
+except ImportError:
+ # Fallback if dynamo imports are not available
+ ForwardPassMetrics = None
+ WorkerStats = None
+ KvStats = None
logger = init_logger(__name__)
@@ -93,6 +113,7 @@ class MQLLMEngineClient(EngineClient):
@@ -93,6 +122,7 @@ class MQLLMEngineClient(EngineClient):
self._errored_with: Optional[BaseException] = None
# Get the configs.
......@@ -3477,7 +3480,7 @@ index f058b1329..2fdb5b8bf 100644
self.model_config = engine_config.model_config
self.decoding_config = engine_config.decoding_config
@@ -117,6 +138,10 @@ class MQLLMEngineClient(EngineClient):
@@ -117,6 +147,10 @@ class MQLLMEngineClient(EngineClient):
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
......@@ -3488,7 +3491,7 @@ index f058b1329..2fdb5b8bf 100644
# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
@@ -131,8 +156,27 @@ class MQLLMEngineClient(EngineClient):
@@ -131,8 +165,27 @@ class MQLLMEngineClient(EngineClient):
# Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready.
self.health_loop: Optional[asyncio.Task] = None
......@@ -3516,7 +3519,7 @@ index f058b1329..2fdb5b8bf 100644
@staticmethod
def is_unsupported_config(vllm_config: VllmConfig):
# Pipeline parallel not yet supported
@@ -182,6 +226,61 @@ class MQLLMEngineClient(EngineClient):
@@ -182,6 +235,76 @@ class MQLLMEngineClient(EngineClient):
except Exception as e:
self._set_errored(e)
......@@ -3553,13 +3556,28 @@ index f058b1329..2fdb5b8bf 100644
+ if self.metrics_publisher is not None and isinstance(
+ metrics, KvMetrics
+ ):
+ self.metrics_publisher.publish(metrics.request_active_slots,
+ metrics.request_total_slots,
+ metrics.kv_active_blocks,
+ metrics.kv_total_blocks,
+ metrics.num_requests_waiting,
+ metrics.gpu_cache_usage_perc,
+ metrics.gpu_prefix_cache_hit_rate)
+ # Construct structured metrics objects
+ worker_stats = WorkerStats(
+ request_active_slots=metrics.request_active_slots,
+ request_total_slots=metrics.request_total_slots,
+ num_requests_waiting=metrics.num_requests_waiting,
+ data_parallel_rank=None
+ )
+
+ kv_stats = KvStats(
+ kv_active_blocks=metrics.kv_active_blocks,
+ kv_total_blocks=metrics.kv_total_blocks,
+ gpu_cache_usage_perc=metrics.gpu_cache_usage_perc,
+ gpu_prefix_cache_hit_rate=metrics.gpu_prefix_cache_hit_rate
+ )
+
+ forward_pass_metrics = ForwardPassMetrics(
+ worker_stats=worker_stats,
+ kv_stats=kv_stats,
+ spec_decode_stats=None
+ )
+
+ self.metrics_publisher.publish(forward_pass_metrics)
+ logger.debug("Metrics successful.")
+
+ # TODO: Investigate sending whole stats object
......@@ -3578,7 +3596,7 @@ index f058b1329..2fdb5b8bf 100644
async def run_output_handler_loop(self):
"""Get RequestOutputs from Engine and stream to Request Queues"""
@@ -250,7 +349,7 @@ class MQLLMEngineClient(EngineClient):
@@ -250,7 +373,7 @@ class MQLLMEngineClient(EngineClient):
# Put each output into the appropriate queue.
elif isinstance(
request_outputs,
......@@ -3587,7 +3605,7 @@ index f058b1329..2fdb5b8bf 100644
self._add_output(request_outputs)
else:
for request_output in request_outputs:
@@ -261,7 +360,7 @@ class MQLLMEngineClient(EngineClient):
@@ -261,7 +384,7 @@ class MQLLMEngineClient(EngineClient):
def _add_output(self, request_output: Union[RequestOutput,
RPCAdapterLoadedResponse,
......@@ -3596,7 +3614,7 @@ index f058b1329..2fdb5b8bf 100644
queue = self.output_queues.get(request_output.request_id)
if queue is not None:
queue.put_nowait(request_output)
@@ -283,12 +382,25 @@ class MQLLMEngineClient(EngineClient):
@@ -283,12 +406,25 @@ class MQLLMEngineClient(EngineClient):
# Wait until server is ready.
response = await self._wait_for_server_rpc(socket)
......@@ -3622,7 +3640,7 @@ index f058b1329..2fdb5b8bf 100644
def close(self):
"""Destroy the ZeroMQ Context."""
@@ -298,6 +410,8 @@ class MQLLMEngineClient(EngineClient):
@@ -298,6 +434,8 @@ class MQLLMEngineClient(EngineClient):
# Cancel background tasks.
if self.health_loop is not None:
self.health_loop.cancel()
......@@ -3631,7 +3649,7 @@ index f058b1329..2fdb5b8bf 100644
if self.output_loop is not None:
self.output_loop.cancel()
@@ -420,6 +534,9 @@ class MQLLMEngineClient(EngineClient):
@@ -420,6 +558,9 @@ class MQLLMEngineClient(EngineClient):
"""
if self._errored_with is not None:
raise self._errored_with
......@@ -3641,7 +3659,7 @@ index f058b1329..2fdb5b8bf 100644
@property
def is_running(self) -> bool:
@@ -478,6 +595,7 @@ class MQLLMEngineClient(EngineClient):
@@ -478,6 +619,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
......@@ -3649,7 +3667,7 @@ index f058b1329..2fdb5b8bf 100644
*,
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[RequestOutput, None]:
@@ -507,7 +625,8 @@ class MQLLMEngineClient(EngineClient):
@@ -507,7 +649,8 @@ class MQLLMEngineClient(EngineClient):
return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers,
......@@ -3659,7 +3677,7 @@ index f058b1329..2fdb5b8bf 100644
@overload
def encode(
@@ -591,6 +710,7 @@ class MQLLMEngineClient(EngineClient):
@@ -591,6 +734,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
......@@ -3667,7 +3685,7 @@ index f058b1329..2fdb5b8bf 100644
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
PoolingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
@@ -636,6 +756,12 @@ class MQLLMEngineClient(EngineClient):
@@ -636,6 +780,12 @@ class MQLLMEngineClient(EngineClient):
else:
lp_bytes = None
......@@ -3680,7 +3698,7 @@ index f058b1329..2fdb5b8bf 100644
request_bytes = pickle.dumps(
RPCProcessRequest(
prompt=prompt,
@@ -645,11 +771,11 @@ class MQLLMEngineClient(EngineClient):
@@ -645,11 +795,11 @@ class MQLLMEngineClient(EngineClient):
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
......@@ -3694,7 +3712,7 @@ index f058b1329..2fdb5b8bf 100644
await self.input_socket.send_multipart(parts, copy=False)
# 4) Stream the RequestOutputs from the output queue. Note
@@ -740,3 +866,22 @@ class MQLLMEngineClient(EngineClient):
@@ -740,3 +890,22 @@ class MQLLMEngineClient(EngineClient):
# Raise on error, otherwise happily return None
if isinstance(request_output, BaseException):
raise request_output
......
......@@ -15,6 +15,8 @@ See the License for the specific language governing permissions and
limitations under the License.
-->
>[!NOTE]
>This information is temporary and will change soon.
# KV Cache Routing
This documentation explains how Key-Value (KV) cache routing works in Dynamo, providing optimized inference for large language models by intelligently directing requests to workers with the most relevant cached data while simultaneously load balancing based on utilization metrics sent by the workers.
......
......@@ -8,7 +8,7 @@ It supports these engines: mistralrs, llamacpp, sglang, vllm, and tensorrt-llm.
Usage:
```
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=2.0] [--kv-gpu-cache-usage-weight=1.0] [--kv-waiting-requests-weight=1.0] [--verbosity (-v|-vv)]
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=1.0] [--router-temperature=0.5] [--verbosity (-v|-vv)]
```
Example: `dynamo run Qwen/Qwen3-0.6B`
......@@ -201,6 +201,8 @@ The only difference from the distributed system above is `--router-mode kv`. The
For performance testing, compare a typical workload with `--router-mode random|round-robin` to see if it can benefit from KV-aware routing.
The argument `--kv-overlap-score-weight` sets the amount weighting on overlaps with prefix caches, which directly contributes to the prefill cost, so a large weight is expected to yield a better TTFT (at the expense of worse ITL). When this is set 0, we do not consider the prefix caches at all (falling back to pure load balancing behavior on the active blocks), in which case we do not require the backend engines to emit any KV events. The argument `--router-temperature` sets the temperature when randomly selecting the workers to route to via softmax sampling on the router cost logits, setting it to 0 recovers the deterministic behavior where the min logit is picked.
## Full usage details
`dynamo run` executes `dynamo-run`. `dynamo-run` is also an example of what can be built in Rust with the `dynamo-llm` and `dynamo-runtime` crates. The following guide shows how to build from source with all the features.
......
......@@ -15,6 +15,9 @@ See the License for the specific language governing permissions and
limitations under the License.
-->
>[!NOTE]
>This information is temporary and will change soon.
# KV Router Performance Tuning
## Overview
......
......@@ -110,20 +110,23 @@ pub struct Flags {
#[arg(long, default_value = "round-robin")]
pub router_mode: RouterMode,
/// Maximum number of batched tokens for KV routing
/// Needed for informing the KV router
/// TODO: derive from vllm args
/// NOTE: this is not actually used for now
#[arg(long, default_value = "8192")]
pub max_num_batched_tokens: Option<u32>,
/// KV Router: Weight for overlap score in worker selection.
/// Higher values prioritize KV cache reuse. Default: 2.0
#[arg(long)]
pub kv_overlap_score_weight: Option<f64>,
/// KV Router: Weight for GPU cache usage in worker selection.
/// Higher values avoid workers with nearly full KV caches. Default: 1.0
#[arg(long)]
pub kv_gpu_cache_usage_weight: Option<f64>,
/// KV Router: Weight for waiting requests in worker selection.
/// Higher values avoid workers with queued requests. Default: 1.0
/// KV Router: Temperature for worker sampling via softmax.
/// Higher values promote more randomness, and 0 fallbacks to deterministic.
/// Default: 0.5
#[arg(long)]
pub kv_waiting_requests_weight: Option<f64>,
pub router_temperature: Option<f64>,
/// Max model context length. Reduce this if you don't have enough VRAM for the full model
/// context length (e.g. Llama 4).
......@@ -211,8 +214,8 @@ impl Flags {
self.router_mode.into(),
KvRouterConfig::new(
self.kv_overlap_score_weight,
self.kv_gpu_cache_usage_weight,
self.kv_waiting_requests_weight,
self.router_temperature,
self.max_num_batched_tokens,
),
)
}
......
......@@ -26,7 +26,14 @@ from vllm.entrypoints.openai.api_server import (
)
from vllm.inputs import TokensPrompt
from dynamo.llm import ModelType, WorkerMetricsPublisher, register_llm
from dynamo.llm import (
ForwardPassMetrics,
KvStats,
ModelType,
WorkerMetricsPublisher,
WorkerStats,
register_llm,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
......@@ -70,15 +77,29 @@ class RequestHandler:
self.engine_client.set_metrics_publisher(self.metrics_publisher)
# Initially send dummy metrics to kick start,
# vLLM will not update stat until forward pass is triggered
self.metrics_publisher.publish(
0, # request_active_slots
1024, # request_total_slots
0, # kv_active_blocks
1024, # kv_total_blocks
0, # num_requests_waiting
0.0, # gpu_cache_usage_perc
0.0, # gpu_prefix_cache_hit_rate
# Create the structured metrics objects
worker_stats = WorkerStats(
request_active_slots=0,
request_total_slots=1024,
num_requests_waiting=0,
data_parallel_rank=None,
)
kv_stats = KvStats(
kv_active_blocks=0,
kv_total_blocks=1024,
gpu_cache_usage_perc=0.0,
gpu_prefix_cache_hit_rate=0.0,
)
metrics = ForwardPassMetrics(
worker_stats=worker_stats, kv_stats=kv_stats, spec_decode_stats=None
)
# Publish the metrics as a single object
self.metrics_publisher.publish(metrics)
task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback(
lambda _: logging.debug("metrics publisher endpoint created")
......
......@@ -72,7 +72,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Client>()?;
m.add_class::<EtcdClient>()?;
m.add_class::<AsyncResponseStream>()?;
m.add_class::<llm::kv::KvRouter>()?;
m.add_class::<llm::disagg_router::DisaggregatedRouter>()?;
m.add_class::<llm::kv::WorkerMetricsPublisher>()?;
m.add_class::<llm::model_card::ModelDeploymentCard>()?;
......
......@@ -29,51 +29,6 @@ use tracing;
use llm_rs::kv_router::protocols::*;
use llm_rs::kv_router::publisher::{create_stored_blocks, KvEventSourceConfig};
#[pyclass]
pub(crate) struct KvRouter {
inner: Arc<llm_rs::kv_router::KvRouter>,
}
#[pymethods]
impl KvRouter {
#[new]
fn new(component: Component, kv_block_size: usize) -> PyResult<Self> {
if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
};
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let inner = llm_rs::kv_router::KvRouter::new(
component.inner.clone(),
kv_block_size as u32,
None,
)
.await
.map_err(to_pyerr)?;
Ok(Self {
inner: Arc::new(inner),
})
})
}
fn schedule<'p>(
&self,
py: Python<'p>,
token_ids: Vec<u32>,
lora_id: u64,
) -> PyResult<Bound<'p, PyAny>> {
let router = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let worker_id = router
.schedule(&token_ids, lora_id)
.await
.map_err(to_pyerr)?;
Ok(worker_id)
})
}
}
#[pyfunction]
pub fn compute_block_hash_for_seq_py(tokens: Vec<u32>, kv_block_size: usize) -> PyResult<Vec<u64>> {
if kv_block_size == 0 {
......@@ -617,25 +572,34 @@ impl KvMetricsAggregator {
fn get_metrics<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
// TODO: update EndpointKvMetrics to match the new ForwardPassMetrics struct
let endpoints = self.inner.get_endpoints();
let load_avg = endpoints.load_avg;
let load_std = endpoints.load_std;
let endpoint_kv_metrics = endpoints
.endpoints
.iter()
.map(|(worker_id, endpoint)| EndpointKvMetrics {
worker_id: *worker_id,
request_active_slots: endpoint.data.worker_stats.request_active_slots,
request_total_slots: endpoint.data.worker_stats.request_total_slots,
kv_active_blocks: endpoint.data.kv_stats.kv_active_blocks,
kv_total_blocks: endpoint.data.kv_stats.kv_total_blocks,
num_requests_waiting: endpoint.data.worker_stats.num_requests_waiting,
gpu_cache_usage_perc: endpoint.data.kv_stats.gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate: endpoint.data.kv_stats.gpu_prefix_cache_hit_rate,
.into_iter()
.map(|(worker_id, endpoint)| {
let metrics = endpoint.data;
let LoadMetrics::EngineLoadMetrics(fwd_pass_metrics) = metrics else {
panic!("Endpoints do not contain forward pass metrics.");
};
EndpointKvMetrics {
worker_id,
request_active_slots: fwd_pass_metrics.worker_stats.request_active_slots,
request_total_slots: fwd_pass_metrics.worker_stats.request_total_slots,
kv_active_blocks: fwd_pass_metrics.kv_stats.kv_active_blocks,
kv_total_blocks: fwd_pass_metrics.kv_stats.kv_total_blocks,
num_requests_waiting: fwd_pass_metrics.worker_stats.num_requests_waiting,
gpu_cache_usage_perc: fwd_pass_metrics.kv_stats.gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate: fwd_pass_metrics.kv_stats.gpu_prefix_cache_hit_rate,
}
})
.collect();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
Ok(AggregatedMetrics {
endpoints: endpoint_kv_metrics,
load_avg: endpoints.load_avg,
load_std: endpoints.load_std,
load_avg,
load_std,
})
})
}
......
......@@ -272,25 +272,6 @@ class Client:
"""
...
class KvRouter:
"""
A router will determine which worker should handle a given request.
"""
...
def __init__(self, drt: DistributedRuntime, component: Component) -> None:
"""
Create a `KvRouter` object that is associated with the `component`
"""
def schedule(self, token_ids: List[int], lora_id: int) -> int:
"""
Return the worker id that should handle the given token ids,
exception will be raised if there is no worker available.
"""
...
class DisaggregatedRouter:
"""
A router that determines whether to perform prefill locally or remotely based on
......
......@@ -32,7 +32,6 @@ from dynamo._core import KvEventPublisher as KvEventPublisher
from dynamo._core import KvIndexer as KvIndexer
from dynamo._core import KvMetricsAggregator as KvMetricsAggregator
from dynamo._core import KvRecorder as KvRecorder
from dynamo._core import KvRouter as KvRouter
from dynamo._core import KvStats as KvStats
from dynamo._core import ModelType as ModelType
from dynamo._core import OverlapScores as OverlapScores
......
......@@ -212,8 +212,20 @@ impl ModelManager {
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
) -> anyhow::Result<Arc<KvRouter>> {
// Determine if we should use KV events based on overlap score weight
let use_kv_events = kv_router_config
.as_ref()
.map(|config| config.overlap_score_weight > 0.0)
.unwrap_or(false);
let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
let chooser = KvRouter::new(component.clone(), kv_cache_block_size, Some(selector)).await?;
let chooser = KvRouter::new(
component.clone(),
kv_cache_block_size,
Some(selector),
use_kv_events,
)
.await?;
let new_kv_chooser = Arc::new(chooser);
self.kv_choosers
.lock()
......
......@@ -23,11 +23,12 @@ pub mod publisher;
pub mod recorder;
pub mod scheduler;
pub mod scoring;
pub mod sequence;
use crate::{
kv_router::{
indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
metrics_aggregator::KvMetricsAggregator,
metrics_aggregator::EndpointCollector,
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
scoring::ProcessedEndpoints,
......@@ -58,25 +59,20 @@ pub trait WorkerSelector {
/// KV Router configuration parameters
#[derive(Debug, Clone)]
pub struct KvRouterConfig {
/// Weight for overlap score in worker selection.
/// Higher values prioritize KV cache reuse. Default: 2.0
pub overlap_score_weight: f64,
/// Weight for GPU cache usage in worker selection.
/// Higher values avoid workers with nearly full KV caches. Default: 1.0
pub gpu_cache_usage_weight: f64,
pub router_temperature: f64,
/// Weight for waiting requests in worker selection.
/// Higher values avoid workers with queued requests. Default: 1.0
pub waiting_requests_weight: f64,
// note: this is not actually used for now
pub max_num_batched_tokens: u32,
}
impl Default for KvRouterConfig {
fn default() -> Self {
Self {
overlap_score_weight: 1.0,
gpu_cache_usage_weight: 1.0,
waiting_requests_weight: 1.0,
router_temperature: 0.5,
max_num_batched_tokens: 8192,
}
}
}
......@@ -86,16 +82,15 @@ impl KvRouterConfig {
/// If a weight is None, the default value will be used.
pub fn new(
overlap_score_weight: Option<f64>,
gpu_cache_usage_weight: Option<f64>,
waiting_requests_weight: Option<f64>,
temperature: Option<f64>,
max_num_batched_tokens: Option<u32>,
) -> Self {
let default = Self::default();
Self {
overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
gpu_cache_usage_weight: gpu_cache_usage_weight
.unwrap_or(default.gpu_cache_usage_weight),
waiting_requests_weight: waiting_requests_weight
.unwrap_or(default.waiting_requests_weight),
router_temperature: temperature.unwrap_or(default.router_temperature),
max_num_batched_tokens: max_num_batched_tokens
.unwrap_or(default.max_num_batched_tokens),
}
}
}
......@@ -103,7 +98,7 @@ impl KvRouterConfig {
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
pub struct KvRouter {
indexer: KvIndexer,
indexer: Option<KvIndexer>,
scheduler: KvScheduler,
block_size: u32,
}
......@@ -113,16 +108,19 @@ impl KvRouter {
component: Component,
block_size: u32,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
use_kv_events: bool,
) -> Result<Self> {
let cancellation_token = component
.drt()
.primary_lease()
.expect("Cannot KV route static workers")
.primary_token();
tracing::info!("KV Routing initialized");
let metrics_aggregator =
KvMetricsAggregator::new(component.clone(), cancellation_token.clone()).await;
let indexer = KvIndexer::new(cancellation_token.clone(), block_size);
EndpointCollector::new(component.clone(), cancellation_token.clone()).await;
let maybe_indexer =
use_kv_events.then(|| KvIndexer::new(cancellation_token.clone(), block_size));
let scheduler = KvScheduler::start(
component.namespace().clone(),
block_size,
......@@ -133,6 +131,7 @@ impl KvRouter {
// [gluo TODO] try subscribe_with_type::<RouterEvent>,
// error checking below will be different.
if let Some(ref indexer) = maybe_indexer {
let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
let kv_events_tx = indexer.event_sender();
......@@ -148,35 +147,31 @@ impl KvRouter {
}
};
if let Err(e) = kv_events_tx.send(event).await {
tracing::debug!("failed to send kv event to indexer; shutting down: {:?}", e);
tracing::debug!(
"failed to send kv event to indexer; shutting down: {:?}",
e
);
}
}
});
}
tracing::info!("KV Routing initialized");
Ok(Self {
indexer: maybe_indexer,
scheduler,
indexer,
block_size,
})
}
// [TODO] indexer needs to take 'lora_id' as parameter
pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<i64> {
// Extracting part of the code in KvRouter::generate() for only
// the decision making part, routing is done by the caller
let isl_tokens = token_ids.len();
let overlap_scores = self
.indexer
.find_matches_for_request(token_ids.as_slice())
.await?;
tracing::debug!("KV router overlap_scores: {:?}", overlap_scores);
let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
Ok(worker_id)
}
/// Give these tokens, find the worker with the best match in it's KV cache.
/// Returned overlap amount is in number of blocks.
async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<(i64, u32)> {
/// Now also takes context_id for request tracking
async fn find_best_match(
&self,
context_id: &str,
tokens: &[u32],
) -> anyhow::Result<(i64, u32)> {
let isl_tokens = tokens.len();
let block_size = self.block_size;
......@@ -187,13 +182,38 @@ impl KvRouter {
.into_iter()
.map(|block| LocalBlockHash(block.block_hash()))
.collect();
let overlap_scores = self.indexer.find_matches(local_block_hashes).await?;
let worker_id = self
let overlap_scores = match &self.indexer {
Some(indexer) => indexer.find_matches(local_block_hashes).await?,
None => Default::default(), // Returns empty/default instance
};
let best_worker_id = self
.scheduler
.schedule(overlap_scores.clone(), isl_tokens)
.schedule(
context_id.to_string(),
isl_tokens,
block_size,
tokens,
overlap_scores.clone(),
)
.await?;
let overlap_amount = overlap_scores.scores.get(&worker_id).copied().unwrap_or(0);
Ok((worker_id, overlap_amount))
let overlap_amount = overlap_scores
.scores
.get(&best_worker_id)
.copied()
.unwrap_or(0);
Ok((best_worker_id, overlap_amount))
}
/// Push a token to a specific request's sequence
pub async fn push(&self, request_id: &String, token: u32) {
self.scheduler.push(request_id, token).await
}
/// Free all blocks associated with a request
pub async fn free(&self, request_id: &String) {
self.scheduler.free(request_id).await
}
/// Get the block size this router was configured with
......@@ -202,6 +222,7 @@ impl KvRouter {
}
}
// NOTE: this would not be usable for now, should deprecate
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
async fn generate(
......@@ -209,7 +230,7 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
request: SingleIn<RouterRequest>,
) -> Result<ManyOut<Annotated<RouterResponse>>> {
let (request, ctx) = request.into_parts();
let (worker_id, _) = self.find_best_match(&request.tokens).await?;
let (worker_id, _) = self.find_best_match(ctx.id(), &request.tokens).await?;
let response = RouterResponse { worker_id };
let response = Annotated::from_data(response);
......@@ -243,13 +264,40 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
match self.inner.client.instance_source.as_ref() {
InstanceSource::Static => self.inner.r#static(request).await,
InstanceSource::Dynamic(_) => {
let (instance_id, overlap_amount) =
self.chooser.find_best_match(&request.token_ids).await?;
// Extract context ID for request tracking
let context_id = request.context().id().to_string();
let (instance_id, overlap_amount) = self
.chooser
.find_best_match(&context_id, &request.token_ids)
.await?;
// Update the request with the estimated prefix hit blocks
let (mut backend_input, context) = request.into_parts();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
let updated_request = context.map(|_| backend_input);
self.inner.direct(updated_request, instance_id).await
// Get the response stream from the worker
let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
// Wrap the stream to track tokens
let stream_context = response_stream.context();
let chooser = self.chooser.clone();
let request_id = context_id.clone();
let wrapped_stream = Box::pin(async_stream::stream! {
while let Some(item) = response_stream.next().await {
// Track tokens if they exist in the response
if let Some(ref output) = item.data {
for token_id in &output.token_ids {
chooser.push(&request_id, *token_id).await;
}
}
yield item;
}
chooser.free(&request_id).await;
});
Ok(ResponseStream::new(wrapped_stream, stream_context))
}
}
}
......
......@@ -15,7 +15,7 @@
use std::sync::Once;
pub use crate::kv_router::protocols::ForwardPassMetrics;
pub use crate::kv_router::protocols::{ForwardPassMetrics, LoadMetrics, PredictiveLoadMetrics};
use crate::kv_router::KV_METRICS_ENDPOINT;
use crate::kv_router::scheduler::Endpoint;
......@@ -28,6 +28,37 @@ use tokio_util::sync::CancellationToken;
static METRICS_WAITING_MESSAGE: Once = Once::new();
static METRICS_FOUND_MESSAGE: Once = Once::new();
pub struct EndpointCollector {
pub service_name: String,
pub endpoints_rx: watch::Receiver<ProcessedEndpoints>,
}
impl EndpointCollector {
pub async fn new(component: Component, cancellation_token: CancellationToken) -> Self {
let (watch_tx, watch_rx) = watch::channel(ProcessedEndpoints::default());
tokio::spawn(collect_endpoints_task(
component.clone(),
watch_tx,
cancellation_token.clone(),
"generate".to_string(),
));
Self {
service_name: component.service_name(),
endpoints_rx: watch_rx,
}
}
pub fn get_endpoints(&self) -> ProcessedEndpoints {
self.endpoints_rx.borrow().clone()
}
pub fn endpoints_watcher(&self) -> watch::Receiver<ProcessedEndpoints> {
self.endpoints_rx.clone()
}
}
pub struct KvMetricsAggregator {
pub service_name: String,
pub endpoints_rx: watch::Receiver<ProcessedEndpoints>,
......@@ -41,6 +72,7 @@ impl KvMetricsAggregator {
component.clone(),
watch_tx,
cancellation_token.clone(),
KV_METRICS_ENDPOINT.to_string(),
));
Self {
......@@ -93,12 +125,16 @@ pub async fn collect_endpoints_task(
component: Component,
watch_tx: watch::Sender<ProcessedEndpoints>,
cancel: CancellationToken,
subject: String,
) {
let backoff_delay = Duration::from_millis(100);
let scrape_timeout = Duration::from_millis(300);
let endpoint = component.endpoint(KV_METRICS_ENDPOINT);
let endpoint = component.endpoint(&subject);
let service_subject = endpoint.subject();
// Keep track of the last sent value to avoid unnecessary updates
let mut last_sent: Option<ProcessedEndpoints> = None;
loop {
tokio::select! {
_ = cancel.cancelled() => {
......@@ -115,31 +151,59 @@ pub async fn collect_endpoints_task(
continue;
}
};
let endpoints: Vec<Endpoint> = unfiltered_endpoints
let endpoints: Vec<Endpoint> = if subject == KV_METRICS_ENDPOINT {
// Original filtering behavior
unfiltered_endpoints
.into_iter()
.filter(|s| s.data.is_some())
.filter_map(|s|
match s.data.unwrap().decode::<ForwardPassMetrics>() {
Ok(data) => Some(Endpoint {
.filter_map(|s| {
s.data?
.decode::<ForwardPassMetrics>()
.map(|data| Endpoint {
name: s.name,
subject: s.subject,
data,
}),
Err(e) => {
tracing::debug!("skip endpoint data that can't be parsed as ForwardPassMetrics: {:?}", e);
None
}
}
)
.collect();
data: LoadMetrics::EngineLoadMetrics(data),
})
.inspect_err(|e| {
tracing::warn!("skip endpoint data that can't be parsed as ForwardPassMetrics: {:?}", e);
})
.ok()
})
.collect()
} else {
// No filtering - just use default LoadMetrics
unfiltered_endpoints
.into_iter()
.map(|s| Endpoint {
name: s.name,
subject: s.subject,
data: LoadMetrics::default(),
})
.collect()
};
tracing::trace!("Found {} endpoints for service: {service_subject}", endpoints.len());
let processed = ProcessedEndpoints::new(endpoints);
if watch_tx.send(processed).is_err() {
tracing::trace!("failed to send processed endpoints; shutting down");
// Only send if different from last sent value
// This is necessary because the watch channel does not track changes
// https://docs.rs/tokio/latest/tokio/sync/watch/struct.Receiver.html#method.has_changed
let should_send = match &last_sent {
Some(last) => last != &processed,
None => true,
};
if should_send {
tracing::trace!("Endpoints changed, sending update for service: {service_subject}");
if watch_tx.send(processed.clone()).is_err() {
tracing::error!("failed to send processed endpoints; shutting down");
break;
}
last_sent = Some(processed);
} else {
tracing::trace!("Endpoints unchanged, skipping update for service: {service_subject}");
}
}
}
}
......
......@@ -39,14 +39,14 @@ pub struct WorkerSelectionResult {
pub overlap_blocks: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct ForwardPassMetrics {
pub worker_stats: WorkerStats,
pub kv_stats: KvStats,
pub spec_decode_stats: Option<SpecDecodeStats>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct WorkerStats {
// https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models
pub data_parallel_rank: Option<u32>,
......@@ -55,7 +55,7 @@ pub struct WorkerStats {
pub num_requests_waiting: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct KvStats {
pub kv_active_blocks: u64,
pub kv_total_blocks: u64,
......@@ -65,7 +65,34 @@ pub struct KvStats {
pub gpu_prefix_cache_hit_rate: f32,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct PredictiveLoadMetrics {
pub kv_active_blocks: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum LoadMetrics {
EngineLoadMetrics(ForwardPassMetrics),
PredictiveLoadMetrics(PredictiveLoadMetrics),
}
impl LoadMetrics {
pub fn kv_active_blocks(&self) -> u64 {
match self {
LoadMetrics::EngineLoadMetrics(metrics) => metrics.kv_stats.kv_active_blocks,
LoadMetrics::PredictiveLoadMetrics(metrics) => metrics.kv_active_blocks,
}
}
}
impl Default for LoadMetrics {
fn default() -> Self {
LoadMetrics::PredictiveLoadMetrics(PredictiveLoadMetrics::default())
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct SpecDecodeStats {
pub num_spec_tokens: Option<u32>,
pub num_drafts: Option<u32>,
......
This diff is collapsed.
......@@ -20,7 +20,7 @@ use std::collections::HashMap;
use crate::kv_router::scheduler::Endpoint;
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)]
pub struct ProcessedEndpoints {
pub endpoints: HashMap<i64, Endpoint>,
pub load_avg: f64,
......@@ -32,7 +32,7 @@ impl ProcessedEndpoints {
// compute some basic statistics
let load_values: Vec<f64> = endpoints
.iter()
.map(|x| x.data.kv_stats.kv_active_blocks as f64)
.map(|endpoint| endpoint.data.kv_active_blocks() as f64)
.collect();
let load_avg = load_values.iter().copied().sum::<f64>() / load_values.len() as f64;
let variance = load_values
......@@ -50,4 +50,15 @@ impl ProcessedEndpoints {
load_std,
}
}
pub fn worker_ids(&self) -> Vec<i64> {
self.endpoints.keys().copied().collect()
}
pub fn active_blocks(&self) -> HashMap<i64, usize> {
self.endpoints
.iter()
.map(|(&worker_id, endpoint)| (worker_id, endpoint.data.kv_active_blocks() as usize))
.collect()
}
}
This diff is collapsed.
......@@ -46,8 +46,9 @@
//! implementation of the main block manager.
use crate::mocker::evictor::LRUEvictor;
use crate::mocker::protocols::{MoveBlock, MoveBlockResponse, PrefillCost, UniqueBlock};
use crate::mocker::protocols::{MoveBlock, MoveBlockResponse, PrefillCost};
use crate::mocker::sequence::ActiveSequence;
use crate::tokens::blocks::UniqueBlock;
use derive_getters::Getters;
use std::collections::{HashMap, HashSet};
use tokio::sync::mpsc;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment