Unverified commit fc229004, authored by Yan Ru Pei and committed by GitHub
Browse files

chore: Remove ZmqKvEventListener binding and rework standalone TRT-LLM example...


chore: Remove ZmqKvEventListener binding and rework standalone TRT-LLM example to use native Python ZMQ (#6164)
Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent c7986b35
......@@ -11,10 +11,11 @@ from contextlib import asynccontextmanager
import numpy as np
import uvicorn
import zmq
import zmq.asyncio
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, ValidationError
from dynamo._core import RadixTree, ZmqKvEventListener
from dynamo._core import RadixTree
logger = logging.getLogger(__name__)
......@@ -110,16 +111,23 @@ class KvRouter:
)
for i in range(num_workers)
]
self.async_context = zmq.asyncio.Context()
self.kv_listeners = [
ZmqKvEventListener(
f"tcp://localhost:{base_kv_events_port + i}", "", block_size
)
self._create_kv_listener(base_kv_events_port + i)
for i in range(num_workers)
]
self.background_tasks: list[asyncio.Task] = []
logger.info("Router initialized")
def _create_kv_listener(self, port: int) -> zmq.asyncio.Socket:
    """Build the async subscriber socket for one KV cache event stream.

    Args:
        port: TCP port on localhost that the socket connects to.

    Returns:
        A ``zmq.SUB`` socket created from ``self.async_context``, connected
        to ``tcp://localhost:<port>``, subscribed with an empty prefix and
        configured with ``RCVTIMEO`` set to 1.
    """
    endpoint = f"tcp://localhost:{port}"
    subscriber = self.async_context.socket(zmq.SUB)
    subscriber.connect(endpoint)
    # An empty subscription prefix matches every incoming message.
    subscriber.setsockopt(zmq.SUBSCRIBE, b"")
    # Receive timeout; ZeroMQ expresses this option in milliseconds.
    subscriber.setsockopt(zmq.RCVTIMEO, 1)
    return subscriber
# -------------------------------------------------------------------------
# Background Tasks
# -------------------------------------------------------------------------
......@@ -153,11 +161,11 @@ class KvRouter:
async def _poll_worker_kv_events(self, worker_id: int):
"""Poll KV events for a single worker and update RadixTree."""
sock = self.kv_listeners[worker_id]
while True:
try:
events: list[str] = await self.kv_listeners[worker_id].get_events()
for event_str in events:
event = json.loads(event_str)
event_bytes = await sock.recv(zmq.NOBLOCK)
event = json.loads(event_bytes)
dump_kv_event(worker_id, event)
self.radix_tree.apply_event(
worker_id, json.dumps(event).encode("utf-8")
......@@ -250,8 +258,11 @@ class KvRouter:
for listener in self.load_listeners:
listener.close()
for listener in self.kv_listeners:
listener.close()
self.context.term()
self.async_context.term()
logger.info("KvRouter shutdown completed")
......
......@@ -10,15 +10,16 @@ if "PYTHONHASHSEED" not in os.environ:
os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python")
import asyncio
import json
import logging
import time
from typing import AsyncGenerator, Optional
import msgpack
import zmq
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig
from dynamo.llm import compute_block_hash_for_seq_py
logger = logging.getLogger(__name__)
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
......@@ -87,7 +88,7 @@ class MetricsPublisher:
class KvEventsPublisher:
"""Publishes KV cache events over ZMQ."""
"""Publishes KV cache events as KvCacheEvent JSON over ZMQ."""
def __init__(self, port: int, block_size: int):
self.context = zmq.Context()
......@@ -95,7 +96,7 @@ class KvEventsPublisher:
self.socket.bind(f"tcp://*:{port}")
self.block_size = block_size
self.partial_block_hashes: set[int] = set()
self.sequence_number = 0
self.next_event_id = 0
def publish_stored(
self,
......@@ -104,34 +105,46 @@ class KvEventsPublisher:
parent_hash: int | None,
block_mm_infos: list[dict | None] | None,
):
"""Publish a BlockStored event.
"""Publish a KvCacheEvent with stored blocks.
Args:
block_hashes: List of block hashes being stored.
token_ids: All token IDs across the blocks.
parent_hash: Hash of the parent block (if any).
block_mm_infos: Per-block multimodal info list. Each element corresponds
to a block and is either the mm_info dict (for blocks containing
image tokens) or None (for text-only blocks).
Computes tokens_hash per block using compute_block_hash_for_seq_py
(including MM info when present) and publishes as KvCacheEvent JSON.
"""
event = {
"type": "BlockStored",
"block_hashes": [to_unsigned_u64(h) for h in block_hashes],
"token_ids": token_ids,
"block_size": self.block_size,
}
if parent_hash is not None:
event["parent_block_hash"] = to_unsigned_u64(parent_hash)
# Compute tokens_hash per block (MM-aware when block_mm_infos provided)
tokens_hashes = compute_block_hash_for_seq_py(
token_ids, self.block_size, block_mm_infos
)
if block_mm_infos is not None:
event["block_mm_infos"] = block_mm_infos
blocks = []
for i, ext_hash in enumerate(block_hashes):
block_data = {
"block_hash": to_unsigned_u64(ext_hash),
"tokens_hash": tokens_hashes[i],
}
mm_info = block_mm_infos[i] if block_mm_infos else None
if mm_info is not None:
block_data["mm_extra_info"] = mm_info
blocks.append(block_data)
self._send([event])
event = {
"event_id": self.next_event_id,
"data": {
"stored": {
"parent_hash": (
to_unsigned_u64(parent_hash)
if parent_hash is not None
else None
),
"blocks": blocks,
}
},
"dp_rank": 0,
}
self.next_event_id += 1
self._send(event)
def publish_removed(self, block_hashes: list[int]):
"""Publish a BlockRemoved event."""
# Filter out partial blocks
"""Publish a KvCacheEvent with removed blocks."""
filtered = []
for h in block_hashes:
if h in self.partial_block_hashes:
......@@ -139,21 +152,29 @@ class KvEventsPublisher:
else:
filtered.append(to_unsigned_u64(h))
if filtered:
self._send([{"type": "BlockRemoved", "block_hashes": filtered}])
if not filtered:
return
def _send(self, events: list[dict]):
"""Send events via ZMQ multipart message."""
batch = [time.time(), events, 0]
event = {
"event_id": self.next_event_id,
"data": {
"removed": {
"block_hashes": filtered,
}
},
"dp_rank": 0,
}
self.next_event_id += 1
self._send(event)
def _send(self, event: dict):
"""Send a single KvCacheEvent as JSON over ZMQ."""
try:
payload = msgpack.packb(batch, use_bin_type=True)
payload = json.dumps(event).encode("utf-8")
except Exception as e:
logger.error(f"msgpack error: {e}")
logger.error(f"JSON encode error: {e}")
return
seq_bytes = self.sequence_number.to_bytes(8, byteorder="big")
self.sequence_number += 1
self.socket.send_multipart([b"", seq_bytes, payload])
self.socket.send(payload)
def close(self):
self.socket.close()
......
......@@ -169,7 +169,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::kv::OverlapScores>()?;
m.add_class::<llm::kv::KvEventPublisher>()?;
m.add_class::<llm::kv::RadixTree>()?;
m.add_class::<llm::kv::ZmqKvEventListener>()?;
m.add_class::<llm::lora::LoRADownloader>()?;
m.add_class::<http::HttpService>()?;
m.add_class::<http::HttpAsyncEngine>()?;
......
......@@ -17,11 +17,15 @@ use tracing;
use llm_rs::kv_router::KvPushRouter as RsKvPushRouter;
use llm_rs::kv_router::protocols::*;
use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks, start_zmq_listener};
use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks};
use llm_rs::protocols::common::timing::RequestTracker;
use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
use serde_json::json;
fn depythonize_block_mm_infos(obj: &Bound<'_, PyAny>) -> PyResult<Vec<Option<BlockExtraInfo>>> {
depythonize(obj).map_err(to_pyerr)
}
#[pyfunction]
#[pyo3(name = "compute_block_hash_for_seq", signature = (tokens, kv_block_size, block_mm_infos=None))]
pub fn compute_block_hash_for_seq_py(
......@@ -36,21 +40,12 @@ pub fn compute_block_hash_for_seq_py(
));
}
// Convert Python block_mm_infos to Rust Vec<Option<BlockExtraInfo>>
let mm_infos_rust: Option<Vec<Option<BlockExtraInfo>>> = block_mm_infos
let mm_infos = block_mm_infos
.as_ref()
.map(|infos_py| {
depythonize::<Vec<Option<BlockExtraInfo>>>(infos_py).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Failed to convert block_mm_infos: {}",
e
))
})
})
.map(depythonize_block_mm_infos)
.transpose()?;
let hashes =
compute_block_hash_for_seq(&tokens, kv_block_size as u32, mm_infos_rust.as_deref());
let hashes = compute_block_hash_for_seq(&tokens, kv_block_size as u32, mm_infos.as_deref());
Ok(hashes.into_iter().map(|h| h.0).collect())
}
......@@ -101,79 +96,6 @@ impl WorkerMetricsPublisher {
}
}
/// A ZMQ-based key-value cache event listener that operates independently
/// of the dynamo runtime or event plane infrastructure.
#[pyclass]
pub(crate) struct ZmqKvEventListener {
// Receiving half of the unbounded channel whose sender is handed to the
// spawned ZMQ listener task. Wrapped in Arc + async Mutex so the future
// returned by `get_events` can own a clone and lock it while draining.
event_receiver: Arc<tokio::sync::Mutex<tokio::sync::mpsc::UnboundedReceiver<KvCacheEvent>>>,
// Token cancelled in `Drop`; a clone is passed to the spawned listener task.
shutdown_token: tokio_util::sync::CancellationToken,
}
#[pymethods]
impl ZmqKvEventListener {
#[new]
#[pyo3(signature = (zmq_endpoint, zmq_topic, kv_block_size))]
fn new(zmq_endpoint: String, zmq_topic: String, kv_block_size: usize) -> PyResult<Self> {
if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
}
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<KvCacheEvent>();
let shutdown_token = tokio_util::sync::CancellationToken::new();
// Standalone listener needs its own event ID counter
let next_event_id = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0));
tokio::spawn(start_zmq_listener(
zmq_endpoint,
zmq_topic,
tx,
shutdown_token.clone(),
kv_block_size as u32,
next_event_id,
));
Ok(Self {
event_receiver: Arc::new(tokio::sync::Mutex::new(rx)),
shutdown_token,
})
})
}
fn get_events<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let receiver = self.event_receiver.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let mut rx = receiver.lock().await;
let mut events = Vec::new();
// Drain all available events
while let Ok(event) = rx.try_recv() {
events.push(event);
}
// Convert events to JSON strings
let json_events: Result<Vec<String>, _> =
events.iter().map(serde_json::to_string).collect();
match json_events {
Ok(json_strings) => Ok(json_strings),
Err(e) => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Failed to serialize events to JSON: {}",
e
))),
}
})
}
}
// manual shutdown needed as it's not tied to the dynamo DRT
impl Drop for ZmqKvEventListener {
// Cancels the shutdown token when the listener object is dropped.
// NOTE(review): presumably the spawned listener task watches a clone of
// this token and exits once it is cancelled — confirm in start_zmq_listener.
fn drop(&mut self) {
self.shutdown_token.cancel();
}
}
#[pyclass]
pub(crate) struct KvEventPublisher {
inner: Arc<llm_rs::kv_router::publisher::KvEventPublisher>,
......@@ -243,17 +165,9 @@ impl KvEventPublisher {
// Use shared monotonic event_id counter from the inner publisher
let event_id = inner.next_event_id();
// Convert Python block_mm_infos to Rust Vec<Option<BlockExtraInfo>>
let mm_infos_rust: Option<Vec<Option<BlockExtraInfo>>> = block_mm_infos
let mm_infos = block_mm_infos
.as_ref()
.map(|infos_py| {
depythonize::<Vec<Option<BlockExtraInfo>>>(infos_py).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Failed to convert block_mm_infos: {}",
e
))
})
})
.map(depythonize_block_mm_infos)
.transpose()?;
py.allow_threads(|| {
......@@ -269,7 +183,7 @@ impl KvEventPublisher {
&block_hashes_u64,
lora_id,
&warning_count,
mm_infos_rust.as_deref(),
mm_infos.as_deref(),
),
}),
dp_rank,
......@@ -862,12 +776,9 @@ impl KvRouter {
None
};
let block_mm_infos: Option<Vec<Option<BlockExtraInfo>>> = if let Some(obj) = block_mm_infos
{
Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
} else {
None
};
let block_mm_infos = block_mm_infos
.map(|obj| depythonize_block_mm_infos(obj.bind(py)))
.transpose()?;
let multi_modal_data: Option<llm_rs::protocols::common::preprocessor::MultimodalDataMap> =
if let Some(obj) = multi_modal_data {
......@@ -962,12 +873,9 @@ impl KvRouter {
None
};
let block_mm_infos: Option<Vec<Option<BlockExtraInfo>>> = if let Some(obj) = block_mm_infos
{
Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
} else {
None
};
let block_mm_infos = block_mm_infos
.map(|obj| depythonize_block_mm_infos(obj.bind(py)))
.transpose()?;
let chooser = self.inner.chooser.clone();
let update_states = request_id.is_some();
......
......@@ -1299,37 +1299,6 @@ class KvbmRequest:
def __init__(self, request_id: int, tokens: List[int], block_size: int) -> None:
...
class ZmqKvEventListener:
"""
A ZMQ-based key-value cache event listener that operates independently
of the dynamo runtime or event plane infrastructure.
"""
def __init__(
self, zmq_endpoint: str, zmq_topic: str, kv_block_size: int
) -> None:
"""
Create a new ZmqKvEventListener instance.
Args:
zmq_endpoint: ZeroMQ endpoint to connect to (e.g., "tcp://127.0.0.1:5557")
zmq_topic: ZeroMQ topic to subscribe to
kv_block_size: Size of KV cache blocks; the native constructor rejects 0
"""
...
async def get_events(self) -> List[str]:
"""
Get all available KV cache events from the ZMQ listener.
The native implementation drains only the events already queued and
does not wait for new ones, so the returned list may be empty.
Returns:
List of JSON-serialized KV cache events as strings
Raises:
ValueError: If events cannot be serialized to JSON
"""
...
class KvRouter:
"""
A KV-aware router that performs intelligent routing based on KV cache overlap.
......
......@@ -26,7 +26,6 @@ from dynamo._core import RadixTree as RadixTree
from dynamo._core import RouterConfig as RouterConfig
from dynamo._core import RouterMode as RouterMode
from dynamo._core import WorkerMetricsPublisher as WorkerMetricsPublisher
from dynamo._core import ZmqKvEventListener as ZmqKvEventListener
from dynamo._core import compute_block_hash_for_seq as compute_block_hash_for_seq
from dynamo._core import fetch_model as fetch_model
from dynamo._core import lora_name_to_id as lora_name_to_id
......
......@@ -103,19 +103,6 @@ pub struct PrefillRouter {
}
impl PrefillRouter {
/// Picks the token IDs and optional per-block multimodal info used for routing.
///
/// The multimodal routing tokens are used when `mm_routing_info` is present
/// and its `routing_token_ids` is non-empty; otherwise the request's own
/// `token_ids` are returned together with `None`.
fn routing_inputs(req: &PreprocessedRequest) -> (&[u32], Option<&[Option<BlockExtraInfo>]>) {
    match req.mm_routing_info.as_ref() {
        Some(mm) if !mm.routing_token_ids.is_empty() => (
            mm.routing_token_ids.as_slice(),
            Some(mm.block_mm_infos.as_slice()),
        ),
        _ => (&req.token_ids, None),
    }
}
/// Create a disabled prefill router that will never activate (passthrough only)
pub fn disabled(
model_manager: Arc<ModelManager>,
......@@ -298,7 +285,7 @@ impl PrefillRouter {
.as_ref()
.and_then(|r| r.priority_jump)
.unwrap_or(0.0);
let (routing_token_ids, block_mm_infos) = Self::routing_inputs(req);
let (routing_token_ids, block_mm_infos) = req.block_mm_routing_info();
match self
.query_prefill_worker(
routing_token_ids,
......
......@@ -19,7 +19,7 @@ use crate::{
kv_router::{
KvRouter,
metrics::RouterRequestMetrics,
protocols::{BlockExtraInfo, TokensWithHashes, WorkerWithDpRank},
protocols::{TokensWithHashes, WorkerWithDpRank},
},
preprocessor::PreprocessedRequest,
protocols::common::{
......@@ -182,21 +182,6 @@ impl KvPushRouter {
KvPushRouter { inner, chooser }
}
/// Returns the token slice and optional block-level multimodal info to route on.
///
/// Multimodal routing info is only honoured when it carries at least one
/// routing token; in every other case the plain `token_ids` of the request
/// are returned with no multimodal info.
fn routing_inputs(
    request: &PreprocessedRequest,
) -> (&[u32], Option<&[Option<BlockExtraInfo>]>) {
    let usable_mm = request
        .mm_routing_info
        .as_ref()
        .filter(|info| !info.routing_token_ids.is_empty());

    match usable_mm {
        Some(info) => (
            info.routing_token_ids.as_slice(),
            Some(info.block_mm_infos.as_slice()),
        ),
        None => (&request.token_ids, None),
    }
}
/// Select a worker for the request, either using a preselected worker or finding the best match.
///
/// When `is_query_only` is false, this also registers the request with the scheduler via `add_request`.
......@@ -212,7 +197,7 @@ impl KvPushRouter {
let priority_jump = routing.and_then(|r| r.priority_jump).unwrap_or(0.0);
let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0);
let expected_output_tokens = routing.and_then(|r| r.expected_output_tokens);
let (routing_token_ids, block_mm_infos) = Self::routing_inputs(request);
let (routing_token_ids, block_mm_infos) = request.block_mm_routing_info();
// Get pre-selected worker based on phase, with backend_instance_id as fallback
let preselected_id = match phase {
......@@ -387,7 +372,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let request_metrics =
RouterRequestMetrics::from_component(self.chooser.client().endpoint.component());
if let Some(ref tracker) = request.tracker {
let (routing_token_ids, _) = Self::routing_inputs(&request);
let (routing_token_ids, _) = request.block_mm_routing_info();
let isl_blocks = routing_token_ids.len().div_ceil(block_size);
tracker.record_kv_hit(overlap_amount, isl_blocks);
tracker.record_isl(
......
......@@ -205,6 +205,19 @@ impl PreprocessedRequest {
pub fn routing_mut(&mut self) -> &mut RoutingHints {
self.routing.get_or_insert_with(RoutingHints::default)
}
/// Returns the token IDs and optional block MM info to use for KV cache overlap computation.
///
/// When `mm_routing_info` is set and holds a non-empty `routing_token_ids`,
/// those tokens are returned together with its `block_mm_infos`. Otherwise the
/// request's primary `token_ids` are returned with `None`.
pub fn block_mm_routing_info(&self) -> (&[TokenIdType], Option<&[Option<BlockExtraInfo>]>) {
    if let Some(mm) = self.mm_routing_info.as_ref() {
        let routing_tokens = mm.routing_token_ids.as_slice();
        if !routing_tokens.is_empty() {
            return (routing_tokens, Some(mm.block_mm_infos.as_slice()));
        }
    }
    (&self.token_ids, None)
}
}
/// [`PreprocessedEmbeddingRequest`] is the internal representation of an embedding request
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment