Unverified Commit fc229004 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: Remove ZmqKvEventListener binding and rework standalone TRT-LLM example...


chore: Remove ZmqKvEventListener binding and rework standalone TRT-LLM example to use native Python ZMQ (#6164)
Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent c7986b35
...@@ -11,10 +11,11 @@ from contextlib import asynccontextmanager ...@@ -11,10 +11,11 @@ from contextlib import asynccontextmanager
import numpy as np import numpy as np
import uvicorn import uvicorn
import zmq import zmq
import zmq.asyncio
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from dynamo._core import RadixTree, ZmqKvEventListener from dynamo._core import RadixTree
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -110,16 +111,23 @@ class KvRouter: ...@@ -110,16 +111,23 @@ class KvRouter:
) )
for i in range(num_workers) for i in range(num_workers)
] ]
self.async_context = zmq.asyncio.Context()
self.kv_listeners = [ self.kv_listeners = [
ZmqKvEventListener( self._create_kv_listener(base_kv_events_port + i)
f"tcp://localhost:{base_kv_events_port + i}", "", block_size
)
for i in range(num_workers) for i in range(num_workers)
] ]
self.background_tasks: list[asyncio.Task] = [] self.background_tasks: list[asyncio.Task] = []
logger.info("Router initialized") logger.info("Router initialized")
def _create_kv_listener(self, port: int) -> zmq.asyncio.Socket:
"""Create an async ZMQ SUB socket for receiving KV cache events."""
sock = self.async_context.socket(zmq.SUB)
sock.connect(f"tcp://localhost:{port}")
sock.setsockopt(zmq.SUBSCRIBE, b"")
sock.setsockopt(zmq.RCVTIMEO, 1)
return sock
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# Background Tasks # Background Tasks
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
...@@ -153,11 +161,11 @@ class KvRouter: ...@@ -153,11 +161,11 @@ class KvRouter:
async def _poll_worker_kv_events(self, worker_id: int): async def _poll_worker_kv_events(self, worker_id: int):
"""Poll KV events for a single worker and update RadixTree.""" """Poll KV events for a single worker and update RadixTree."""
sock = self.kv_listeners[worker_id]
while True: while True:
try: try:
events: list[str] = await self.kv_listeners[worker_id].get_events() event_bytes = await sock.recv(zmq.NOBLOCK)
for event_str in events: event = json.loads(event_bytes)
event = json.loads(event_str)
dump_kv_event(worker_id, event) dump_kv_event(worker_id, event)
self.radix_tree.apply_event( self.radix_tree.apply_event(
worker_id, json.dumps(event).encode("utf-8") worker_id, json.dumps(event).encode("utf-8")
...@@ -250,8 +258,11 @@ class KvRouter: ...@@ -250,8 +258,11 @@ class KvRouter:
for listener in self.load_listeners: for listener in self.load_listeners:
listener.close() listener.close()
for listener in self.kv_listeners:
listener.close()
self.context.term() self.context.term()
self.async_context.term()
logger.info("KvRouter shutdown completed") logger.info("KvRouter shutdown completed")
......
...@@ -10,15 +10,16 @@ if "PYTHONHASHSEED" not in os.environ: ...@@ -10,15 +10,16 @@ if "PYTHONHASHSEED" not in os.environ:
os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python") os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python")
import asyncio import asyncio
import json
import logging import logging
import time
from typing import AsyncGenerator, Optional from typing import AsyncGenerator, Optional
import msgpack
import zmq import zmq
from tensorrt_llm import LLM from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig from tensorrt_llm.llmapi import KvCacheConfig
from dynamo.llm import compute_block_hash_for_seq_py
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024 DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
...@@ -87,7 +88,7 @@ class MetricsPublisher: ...@@ -87,7 +88,7 @@ class MetricsPublisher:
class KvEventsPublisher: class KvEventsPublisher:
"""Publishes KV cache events over ZMQ.""" """Publishes KV cache events as KvCacheEvent JSON over ZMQ."""
def __init__(self, port: int, block_size: int): def __init__(self, port: int, block_size: int):
self.context = zmq.Context() self.context = zmq.Context()
...@@ -95,7 +96,7 @@ class KvEventsPublisher: ...@@ -95,7 +96,7 @@ class KvEventsPublisher:
self.socket.bind(f"tcp://*:{port}") self.socket.bind(f"tcp://*:{port}")
self.block_size = block_size self.block_size = block_size
self.partial_block_hashes: set[int] = set() self.partial_block_hashes: set[int] = set()
self.sequence_number = 0 self.next_event_id = 0
def publish_stored( def publish_stored(
self, self,
...@@ -104,34 +105,46 @@ class KvEventsPublisher: ...@@ -104,34 +105,46 @@ class KvEventsPublisher:
parent_hash: int | None, parent_hash: int | None,
block_mm_infos: list[dict | None] | None, block_mm_infos: list[dict | None] | None,
): ):
"""Publish a BlockStored event. """Publish a KvCacheEvent with stored blocks.
Args: Computes tokens_hash per block using compute_block_hash_for_seq_py
block_hashes: List of block hashes being stored. (including MM info when present) and publishes as KvCacheEvent JSON.
token_ids: All token IDs across the blocks.
parent_hash: Hash of the parent block (if any).
block_mm_infos: Per-block multimodal info list. Each element corresponds
to a block and is either the mm_info dict (for blocks containing
image tokens) or None (for text-only blocks).
""" """
event = { # Compute tokens_hash per block (MM-aware when block_mm_infos provided)
"type": "BlockStored", tokens_hashes = compute_block_hash_for_seq_py(
"block_hashes": [to_unsigned_u64(h) for h in block_hashes], token_ids, self.block_size, block_mm_infos
"token_ids": token_ids, )
"block_size": self.block_size,
}
if parent_hash is not None:
event["parent_block_hash"] = to_unsigned_u64(parent_hash)
if block_mm_infos is not None: blocks = []
event["block_mm_infos"] = block_mm_infos for i, ext_hash in enumerate(block_hashes):
block_data = {
"block_hash": to_unsigned_u64(ext_hash),
"tokens_hash": tokens_hashes[i],
}
mm_info = block_mm_infos[i] if block_mm_infos else None
if mm_info is not None:
block_data["mm_extra_info"] = mm_info
blocks.append(block_data)
self._send([event]) event = {
"event_id": self.next_event_id,
"data": {
"stored": {
"parent_hash": (
to_unsigned_u64(parent_hash)
if parent_hash is not None
else None
),
"blocks": blocks,
}
},
"dp_rank": 0,
}
self.next_event_id += 1
self._send(event)
def publish_removed(self, block_hashes: list[int]): def publish_removed(self, block_hashes: list[int]):
"""Publish a BlockRemoved event.""" """Publish a KvCacheEvent with removed blocks."""
# Filter out partial blocks
filtered = [] filtered = []
for h in block_hashes: for h in block_hashes:
if h in self.partial_block_hashes: if h in self.partial_block_hashes:
...@@ -139,21 +152,29 @@ class KvEventsPublisher: ...@@ -139,21 +152,29 @@ class KvEventsPublisher:
else: else:
filtered.append(to_unsigned_u64(h)) filtered.append(to_unsigned_u64(h))
if filtered: if not filtered:
self._send([{"type": "BlockRemoved", "block_hashes": filtered}]) return
def _send(self, events: list[dict]): event = {
"""Send events via ZMQ multipart message.""" "event_id": self.next_event_id,
batch = [time.time(), events, 0] "data": {
"removed": {
"block_hashes": filtered,
}
},
"dp_rank": 0,
}
self.next_event_id += 1
self._send(event)
def _send(self, event: dict):
"""Send a single KvCacheEvent as JSON over ZMQ."""
try: try:
payload = msgpack.packb(batch, use_bin_type=True) payload = json.dumps(event).encode("utf-8")
except Exception as e: except Exception as e:
logger.error(f"msgpack error: {e}") logger.error(f"JSON encode error: {e}")
return return
self.socket.send(payload)
seq_bytes = self.sequence_number.to_bytes(8, byteorder="big")
self.sequence_number += 1
self.socket.send_multipart([b"", seq_bytes, payload])
def close(self): def close(self):
self.socket.close() self.socket.close()
......
...@@ -169,7 +169,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -169,7 +169,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::kv::OverlapScores>()?; m.add_class::<llm::kv::OverlapScores>()?;
m.add_class::<llm::kv::KvEventPublisher>()?; m.add_class::<llm::kv::KvEventPublisher>()?;
m.add_class::<llm::kv::RadixTree>()?; m.add_class::<llm::kv::RadixTree>()?;
m.add_class::<llm::kv::ZmqKvEventListener>()?;
m.add_class::<llm::lora::LoRADownloader>()?; m.add_class::<llm::lora::LoRADownloader>()?;
m.add_class::<http::HttpService>()?; m.add_class::<http::HttpService>()?;
m.add_class::<http::HttpAsyncEngine>()?; m.add_class::<http::HttpAsyncEngine>()?;
......
...@@ -17,11 +17,15 @@ use tracing; ...@@ -17,11 +17,15 @@ use tracing;
use llm_rs::kv_router::KvPushRouter as RsKvPushRouter; use llm_rs::kv_router::KvPushRouter as RsKvPushRouter;
use llm_rs::kv_router::protocols::*; use llm_rs::kv_router::protocols::*;
use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks, start_zmq_listener}; use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks};
use llm_rs::protocols::common::timing::RequestTracker; use llm_rs::protocols::common::timing::RequestTracker;
use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions}; use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
use serde_json::json; use serde_json::json;
fn depythonize_block_mm_infos(obj: &Bound<'_, PyAny>) -> PyResult<Vec<Option<BlockExtraInfo>>> {
depythonize(obj).map_err(to_pyerr)
}
#[pyfunction] #[pyfunction]
#[pyo3(name = "compute_block_hash_for_seq", signature = (tokens, kv_block_size, block_mm_infos=None))] #[pyo3(name = "compute_block_hash_for_seq", signature = (tokens, kv_block_size, block_mm_infos=None))]
pub fn compute_block_hash_for_seq_py( pub fn compute_block_hash_for_seq_py(
...@@ -36,21 +40,12 @@ pub fn compute_block_hash_for_seq_py( ...@@ -36,21 +40,12 @@ pub fn compute_block_hash_for_seq_py(
)); ));
} }
// Convert Python block_mm_infos to Rust Vec<Option<BlockExtraInfo>> let mm_infos = block_mm_infos
let mm_infos_rust: Option<Vec<Option<BlockExtraInfo>>> = block_mm_infos
.as_ref() .as_ref()
.map(|infos_py| { .map(depythonize_block_mm_infos)
depythonize::<Vec<Option<BlockExtraInfo>>>(infos_py).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Failed to convert block_mm_infos: {}",
e
))
})
})
.transpose()?; .transpose()?;
let hashes = let hashes = compute_block_hash_for_seq(&tokens, kv_block_size as u32, mm_infos.as_deref());
compute_block_hash_for_seq(&tokens, kv_block_size as u32, mm_infos_rust.as_deref());
Ok(hashes.into_iter().map(|h| h.0).collect()) Ok(hashes.into_iter().map(|h| h.0).collect())
} }
...@@ -101,79 +96,6 @@ impl WorkerMetricsPublisher { ...@@ -101,79 +96,6 @@ impl WorkerMetricsPublisher {
} }
} }
/// A ZMQ-based key-value cache event listener that operates independently
/// of the dynamo runtime or event plane infrastructure.
#[pyclass]
pub(crate) struct ZmqKvEventListener {
event_receiver: Arc<tokio::sync::Mutex<tokio::sync::mpsc::UnboundedReceiver<KvCacheEvent>>>,
shutdown_token: tokio_util::sync::CancellationToken,
}
#[pymethods]
impl ZmqKvEventListener {
#[new]
#[pyo3(signature = (zmq_endpoint, zmq_topic, kv_block_size))]
fn new(zmq_endpoint: String, zmq_topic: String, kv_block_size: usize) -> PyResult<Self> {
if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
}
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<KvCacheEvent>();
let shutdown_token = tokio_util::sync::CancellationToken::new();
// Standalone listener needs its own event ID counter
let next_event_id = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0));
tokio::spawn(start_zmq_listener(
zmq_endpoint,
zmq_topic,
tx,
shutdown_token.clone(),
kv_block_size as u32,
next_event_id,
));
Ok(Self {
event_receiver: Arc::new(tokio::sync::Mutex::new(rx)),
shutdown_token,
})
})
}
fn get_events<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let receiver = self.event_receiver.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let mut rx = receiver.lock().await;
let mut events = Vec::new();
// Drain all available events
while let Ok(event) = rx.try_recv() {
events.push(event);
}
// Convert events to JSON strings
let json_events: Result<Vec<String>, _> =
events.iter().map(serde_json::to_string).collect();
match json_events {
Ok(json_strings) => Ok(json_strings),
Err(e) => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Failed to serialize events to JSON: {}",
e
))),
}
})
}
}
// manual shutdown needed as it's not tied to the dynamo DRT
impl Drop for ZmqKvEventListener {
fn drop(&mut self) {
self.shutdown_token.cancel();
}
}
#[pyclass] #[pyclass]
pub(crate) struct KvEventPublisher { pub(crate) struct KvEventPublisher {
inner: Arc<llm_rs::kv_router::publisher::KvEventPublisher>, inner: Arc<llm_rs::kv_router::publisher::KvEventPublisher>,
...@@ -243,17 +165,9 @@ impl KvEventPublisher { ...@@ -243,17 +165,9 @@ impl KvEventPublisher {
// Use shared monotonic event_id counter from the inner publisher // Use shared monotonic event_id counter from the inner publisher
let event_id = inner.next_event_id(); let event_id = inner.next_event_id();
// Convert Python block_mm_infos to Rust Vec<Option<BlockExtraInfo>> let mm_infos = block_mm_infos
let mm_infos_rust: Option<Vec<Option<BlockExtraInfo>>> = block_mm_infos
.as_ref() .as_ref()
.map(|infos_py| { .map(depythonize_block_mm_infos)
depythonize::<Vec<Option<BlockExtraInfo>>>(infos_py).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Failed to convert block_mm_infos: {}",
e
))
})
})
.transpose()?; .transpose()?;
py.allow_threads(|| { py.allow_threads(|| {
...@@ -269,7 +183,7 @@ impl KvEventPublisher { ...@@ -269,7 +183,7 @@ impl KvEventPublisher {
&block_hashes_u64, &block_hashes_u64,
lora_id, lora_id,
&warning_count, &warning_count,
mm_infos_rust.as_deref(), mm_infos.as_deref(),
), ),
}), }),
dp_rank, dp_rank,
...@@ -862,12 +776,9 @@ impl KvRouter { ...@@ -862,12 +776,9 @@ impl KvRouter {
None None
}; };
let block_mm_infos: Option<Vec<Option<BlockExtraInfo>>> = if let Some(obj) = block_mm_infos let block_mm_infos = block_mm_infos
{ .map(|obj| depythonize_block_mm_infos(obj.bind(py)))
Some(depythonize(obj.bind(py)).map_err(to_pyerr)?) .transpose()?;
} else {
None
};
let multi_modal_data: Option<llm_rs::protocols::common::preprocessor::MultimodalDataMap> = let multi_modal_data: Option<llm_rs::protocols::common::preprocessor::MultimodalDataMap> =
if let Some(obj) = multi_modal_data { if let Some(obj) = multi_modal_data {
...@@ -962,12 +873,9 @@ impl KvRouter { ...@@ -962,12 +873,9 @@ impl KvRouter {
None None
}; };
let block_mm_infos: Option<Vec<Option<BlockExtraInfo>>> = if let Some(obj) = block_mm_infos let block_mm_infos = block_mm_infos
{ .map(|obj| depythonize_block_mm_infos(obj.bind(py)))
Some(depythonize(obj.bind(py)).map_err(to_pyerr)?) .transpose()?;
} else {
None
};
let chooser = self.inner.chooser.clone(); let chooser = self.inner.chooser.clone();
let update_states = request_id.is_some(); let update_states = request_id.is_some();
......
...@@ -1299,37 +1299,6 @@ class KvbmRequest: ...@@ -1299,37 +1299,6 @@ class KvbmRequest:
def __init__(self, request_id: int, tokens: List[int], block_size: int) -> None: def __init__(self, request_id: int, tokens: List[int], block_size: int) -> None:
... ...
class ZmqKvEventListener:
"""
A ZMQ-based key-value cache event listener that operates independently
of the dynamo runtime or event plane infrastructure.
"""
def __init__(
self, zmq_endpoint: str, zmq_topic: str, kv_block_size: int
) -> None:
"""
Create a new ZmqKvEventListener instance.
Args:
zmq_endpoint: ZeroMQ endpoint to connect to (e.g., "tcp://127.0.0.1:5557")
zmq_topic: ZeroMQ topic to subscribe to
kv_block_size: Size of KV cache blocks
"""
...
async def get_events(self) -> List[str]:
"""
Get all available KV cache events from the ZMQ listener.
Returns:
List of JSON-serialized KV cache events as strings
Raises:
ValueError: If events cannot be serialized to JSON
"""
...
class KvRouter: class KvRouter:
""" """
A KV-aware router that performs intelligent routing based on KV cache overlap. A KV-aware router that performs intelligent routing based on KV cache overlap.
......
...@@ -26,7 +26,6 @@ from dynamo._core import RadixTree as RadixTree ...@@ -26,7 +26,6 @@ from dynamo._core import RadixTree as RadixTree
from dynamo._core import RouterConfig as RouterConfig from dynamo._core import RouterConfig as RouterConfig
from dynamo._core import RouterMode as RouterMode from dynamo._core import RouterMode as RouterMode
from dynamo._core import WorkerMetricsPublisher as WorkerMetricsPublisher from dynamo._core import WorkerMetricsPublisher as WorkerMetricsPublisher
from dynamo._core import ZmqKvEventListener as ZmqKvEventListener
from dynamo._core import compute_block_hash_for_seq as compute_block_hash_for_seq from dynamo._core import compute_block_hash_for_seq as compute_block_hash_for_seq
from dynamo._core import fetch_model as fetch_model from dynamo._core import fetch_model as fetch_model
from dynamo._core import lora_name_to_id as lora_name_to_id from dynamo._core import lora_name_to_id as lora_name_to_id
......
...@@ -103,19 +103,6 @@ pub struct PrefillRouter { ...@@ -103,19 +103,6 @@ pub struct PrefillRouter {
} }
impl PrefillRouter { impl PrefillRouter {
fn routing_inputs(req: &PreprocessedRequest) -> (&[u32], Option<&[Option<BlockExtraInfo>]>) {
if let Some(mm_routing_info) = req.mm_routing_info.as_ref() {
let routing_tokens = mm_routing_info.routing_token_ids.as_slice();
if !routing_tokens.is_empty() {
return (
routing_tokens,
Some(mm_routing_info.block_mm_infos.as_slice()),
);
}
}
(&req.token_ids, None)
}
/// Create a disabled prefill router that will never activate (passthrough only) /// Create a disabled prefill router that will never activate (passthrough only)
pub fn disabled( pub fn disabled(
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
...@@ -298,7 +285,7 @@ impl PrefillRouter { ...@@ -298,7 +285,7 @@ impl PrefillRouter {
.as_ref() .as_ref()
.and_then(|r| r.priority_jump) .and_then(|r| r.priority_jump)
.unwrap_or(0.0); .unwrap_or(0.0);
let (routing_token_ids, block_mm_infos) = Self::routing_inputs(req); let (routing_token_ids, block_mm_infos) = req.block_mm_routing_info();
match self match self
.query_prefill_worker( .query_prefill_worker(
routing_token_ids, routing_token_ids,
......
...@@ -19,7 +19,7 @@ use crate::{ ...@@ -19,7 +19,7 @@ use crate::{
kv_router::{ kv_router::{
KvRouter, KvRouter,
metrics::RouterRequestMetrics, metrics::RouterRequestMetrics,
protocols::{BlockExtraInfo, TokensWithHashes, WorkerWithDpRank}, protocols::{TokensWithHashes, WorkerWithDpRank},
}, },
preprocessor::PreprocessedRequest, preprocessor::PreprocessedRequest,
protocols::common::{ protocols::common::{
...@@ -182,21 +182,6 @@ impl KvPushRouter { ...@@ -182,21 +182,6 @@ impl KvPushRouter {
KvPushRouter { inner, chooser } KvPushRouter { inner, chooser }
} }
fn routing_inputs(
request: &PreprocessedRequest,
) -> (&[u32], Option<&[Option<BlockExtraInfo>]>) {
if let Some(mm_routing_info) = request.mm_routing_info.as_ref() {
let routing_tokens = mm_routing_info.routing_token_ids.as_slice();
if !routing_tokens.is_empty() {
return (
routing_tokens,
Some(mm_routing_info.block_mm_infos.as_slice()),
);
}
}
(&request.token_ids, None)
}
/// Select a worker for the request, either using a preselected worker or finding the best match. /// Select a worker for the request, either using a preselected worker or finding the best match.
/// ///
/// When `is_query_only` is false, this also registers the request with the scheduler via `add_request`. /// When `is_query_only` is false, this also registers the request with the scheduler via `add_request`.
...@@ -212,7 +197,7 @@ impl KvPushRouter { ...@@ -212,7 +197,7 @@ impl KvPushRouter {
let priority_jump = routing.and_then(|r| r.priority_jump).unwrap_or(0.0); let priority_jump = routing.and_then(|r| r.priority_jump).unwrap_or(0.0);
let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0); let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0);
let expected_output_tokens = routing.and_then(|r| r.expected_output_tokens); let expected_output_tokens = routing.and_then(|r| r.expected_output_tokens);
let (routing_token_ids, block_mm_infos) = Self::routing_inputs(request); let (routing_token_ids, block_mm_infos) = request.block_mm_routing_info();
// Get pre-selected worker based on phase, with backend_instance_id as fallback // Get pre-selected worker based on phase, with backend_instance_id as fallback
let preselected_id = match phase { let preselected_id = match phase {
...@@ -387,7 +372,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -387,7 +372,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let request_metrics = let request_metrics =
RouterRequestMetrics::from_component(self.chooser.client().endpoint.component()); RouterRequestMetrics::from_component(self.chooser.client().endpoint.component());
if let Some(ref tracker) = request.tracker { if let Some(ref tracker) = request.tracker {
let (routing_token_ids, _) = Self::routing_inputs(&request); let (routing_token_ids, _) = request.block_mm_routing_info();
let isl_blocks = routing_token_ids.len().div_ceil(block_size); let isl_blocks = routing_token_ids.len().div_ceil(block_size);
tracker.record_kv_hit(overlap_amount, isl_blocks); tracker.record_kv_hit(overlap_amount, isl_blocks);
tracker.record_isl( tracker.record_isl(
......
...@@ -205,6 +205,19 @@ impl PreprocessedRequest { ...@@ -205,6 +205,19 @@ impl PreprocessedRequest {
pub fn routing_mut(&mut self) -> &mut RoutingHints { pub fn routing_mut(&mut self) -> &mut RoutingHints {
self.routing.get_or_insert_with(RoutingHints::default) self.routing.get_or_insert_with(RoutingHints::default)
} }
/// Extract the token IDs and optional block MM info used for KV cache overlap computation.
/// Falls back to the request's primary `token_ids` when no multimodal routing info is present.
pub fn block_mm_routing_info(&self) -> (&[TokenIdType], Option<&[Option<BlockExtraInfo>]>) {
let Some(mm) = self.mm_routing_info.as_ref() else {
return (&self.token_ids, None);
};
let tokens = mm.routing_token_ids.as_slice();
if tokens.is_empty() {
return (&self.token_ids, None);
}
(tokens, Some(mm.block_mm_infos.as_slice()))
}
} }
/// [`PreprocessedEmbeddingRequest`] is the internal representation of an embedding request /// [`PreprocessedEmbeddingRequest`] is the internal representation of an embedding request
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment