Commit 4b6cfc1b authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: KV recorder for dumping router events into a jsonl (#505)

parent 4c7dceca
......@@ -72,6 +72,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::kv::AggregatedMetrics>()?;
m.add_class::<llm::kv::KvMetricsAggregator>()?;
m.add_class::<llm::kv::KvEventPublisher>()?;
m.add_class::<llm::kv::KvRecorder>()?;
m.add_class::<http::HttpService>()?;
m.add_class::<http::HttpError>()?;
m.add_class::<http::HttpAsyncEngine>()?;
......
......@@ -392,3 +392,122 @@ impl KvMetricsAggregator {
})
}
}
#[pyclass]
pub(crate) struct KvRecorder {
inner: Arc<llm_rs::kv_router::recorder::KvRecorder>,
}
#[pymethods]
impl KvRecorder {
#[new]
#[pyo3(signature = (component, output_path=None, max_lines_per_file=None, max_count=None, max_time=None))]
fn new(
component: Component,
output_path: Option<String>,
max_lines_per_file: Option<usize>,
max_count: Option<usize>,
max_time: Option<f64>,
) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let token = component.inner.drt().runtime().child_token();
// Create a temp path if none provided
let path = match output_path {
Some(p) => p,
None => {
let temp_dir = std::env::temp_dir();
temp_dir
.join("kv_events.jsonl")
.to_string_lossy()
.to_string()
}
};
let inner = llm_rs::kv_router::recorder::KvRecorder::new(
token.clone(),
path,
max_lines_per_file,
max_count,
max_time,
)
.await
.map_err(to_pyerr)?;
// Subscribe to KV events
let mut kv_events_rx = component
.inner
.subscribe(llm_rs::kv_router::KV_EVENT_SUBJECT)
.await
.map_err(to_pyerr)?;
let event_tx = inner.event_sender();
// Spawn a task to forward events to the recorder
tokio::spawn(async move {
while let Some(event) = kv_events_rx.next().await {
let event: llm_rs::kv_router::indexer::RouterEvent =
serde_json::from_slice(&event.payload).unwrap();
tracing::debug!("KvRecorder received kv event: {:?}", event);
if let Err(e) = event_tx.send(event).await {
tracing::trace!(
"KvRecorder failed to send kv event; shutting down: {:?}",
e
);
}
}
});
Ok(Self {
inner: Arc::new(inner),
})
})
}
fn event_count<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let recorder = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let count = recorder.event_count().await;
Ok(count)
})
}
fn elapsed_time<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let recorder = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
match recorder.elapsed_time().await {
Ok(elapsed) => Ok(elapsed.as_secs_f64()),
Err(_) => Ok(0.0), // Return 0.0 when no events have been received yet
}
})
}
#[pyo3(signature = (indexer, timed=false, max_count=None, max_time=None))]
fn replay_events<'py>(
&self,
py: Python<'py>,
indexer: &KvIndexer,
timed: bool,
max_count: Option<usize>,
max_time: Option<f64>,
) -> PyResult<Bound<'py, PyAny>> {
let event_tx = indexer.inner.event_sender();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let count = llm_rs::kv_router::recorder::KvRecorder::send_events(
"dummy_path", // This doesn't matter as we'll use the provided event_tx
&event_tx,
timed,
max_count,
max_time,
)
.await
.map_err(to_pyerr)?;
Ok(count)
})
}
fn shutdown(&self) -> PyResult<()> {
self.inner.shutdown();
Ok(())
}
}
......@@ -108,7 +108,6 @@ class Component:
"""
...
class Endpoint:
"""
An Endpoint is a single API endpoint
......@@ -330,6 +329,79 @@ class KvIndexer:
"""
...
class KvRecorder:
"""
A recorder for KV Router events.
"""
...
def __init__(
self,
component: Component,
output_path: Optional[str] = None,
max_lines_per_file: Optional[int] = None,
max_count: Optional[int] = None,
max_time: Optional[float] = None,
) -> None:
"""
Create a new KvRecorder instance.
Args:
component: The component to associate with this recorder
output_path: Path to the JSONL file to write events to
max_lines_per_file: Maximum number of lines per file before rotating to a new file
max_count: Maximum number of events to record before shutting down
max_time: Maximum duration in seconds to record before shutting down
"""
...
def event_count(self) -> int:
"""
Get the count of recorded events.
Returns:
The number of events recorded
"""
...
def elapsed_time(self) -> float:
"""
Get the elapsed time since the recorder was started.
Returns:
The elapsed time in seconds as a float
"""
...
def replay_events(
self,
indexer: KvIndexer,
timed: bool = False,
max_count: Optional[int] = None,
max_time: Optional[float] = None,
) -> int:
"""
Populate an indexer with the recorded events.
Args:
indexer: The KvIndexer to populate with events
timed: If true, events will be sent according to their recorded timestamps.
If false, events will be sent without any delay in between.
max_count: Maximum number of events to send before stopping
max_time: Maximum duration in seconds to send events before stopping
Returns:
The number of events sent to the indexer
"""
...
def shutdown(self) -> None:
"""
Shutdown the recorder.
"""
...
class AggregatedMetrics:
"""
A collection of metrics of the endpoints
......@@ -362,12 +434,23 @@ class KvEventPublisher:
...
def __init__(self, component: Component, worker_id: int, kv_block_size: int) -> None:
def __init__(
self, component: Component, worker_id: int, kv_block_size: int
) -> None:
"""
Create a `KvEventPublisher` object
"""
def publish_stored(self, event_id, int, token_ids: List[int], num_block_tokens: List[int], block_hashes: List[int], lora_id: int, parent_hash: Optional[int] = None) -> None:
def publish_stored(
self,
event_id,
int,
token_ids: List[int],
num_block_tokens: List[int],
block_hashes: List[int],
lora_id: int,
parent_hash: Optional[int] = None,
) -> None:
"""
Publish a KV stored event.
"""
......
......@@ -22,5 +22,6 @@ from dynamo._core import KvEventPublisher as KvEventPublisher
from dynamo._core import KvIndexer as KvIndexer
from dynamo._core import KvMetricsAggregator as KvMetricsAggregator
from dynamo._core import KvMetricsPublisher as KvMetricsPublisher
from dynamo._core import KvRecorder as KvRecorder
from dynamo._core import KvRouter as KvRouter
from dynamo._core import OverlapScores as OverlapScores
......@@ -30,6 +30,7 @@ pub mod indexer;
pub mod metrics_aggregator;
pub mod protocols;
pub mod publisher;
pub mod recorder;
pub mod scheduler;
pub mod scoring;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::kv_router::indexer::RouterEvent;
use crate::recorder::Recorder;
// Type alias for backward compatibility
pub type KvRecorder = Recorder<RouterEvent>;
#[cfg(test)]
mod tests {
use super::*;
use crate::kv_router::indexer::KvIndexer;
use crate::kv_router::indexer::WorkerId;
use crate::kv_router::protocols::*;
use std::time::Duration;
use tempfile::tempdir;
use tokio::fs;
use tokio_util::sync::CancellationToken;
fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
hashes
.iter()
.map(|i| KvCacheStoredBlockData {
tokens_hash: LocalBlockHash(*i),
block_hash: ExternalSequenceBlockHash(*i * 100),
})
.collect()
}
fn add_blocks(
hashes: Vec<u64>,
parent_hash: Option<ExternalSequenceBlockHash>,
) -> KvCacheEventData {
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: make_blocks(hashes),
})
}
fn create_store_event(
worker_id: WorkerId,
event_id: u64,
hashes: Vec<u64>,
parent: Option<ExternalSequenceBlockHash>,
) -> RouterEvent {
RouterEvent::new(
worker_id,
KvCacheEvent {
event_id,
data: add_blocks(hashes, parent),
},
)
}
fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent {
RouterEvent::new(
worker_id,
KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: hashes
.iter()
.map(|i| ExternalSequenceBlockHash(*i * 100))
.collect(),
}),
},
)
}
#[tokio::test]
async fn test_recorder_streams_events_to_file() {
// Create a temporary directory for output files
let dir = tempdir().unwrap();
let file_path = dir.path().join("kv_events.jsonl");
// Part 1: Record events to a file
let token = CancellationToken::new();
let recorder = KvRecorder::new(token.clone(), &file_path, None, None, None)
.await
.unwrap();
let event_tx = recorder.event_sender();
// Create first event from worker 1 using helper function
let event1 = create_store_event(1, 42, vec![1, 2, 3], None);
// Create second event from worker 2 using helper function
let event2 = create_remove_event(1, 43, vec![2, 3]);
// Send both events one after another
event_tx.send(event1).await.unwrap();
event_tx.send(event2).await.unwrap();
// Allow some time for processing
tokio::time::sleep(Duration::from_millis(10)).await;
// Check that both events were recorded
assert_eq!(recorder.event_count().await, 2);
// Force shutdown to flush file
recorder.shutdown();
tokio::time::sleep(Duration::from_millis(10)).await;
// Read the file and verify content
let content = fs::read_to_string(&file_path).await.unwrap();
let lines: Vec<&str> = content.lines().collect();
// Print the content of the JSONL file
println!("JSONL file content:");
for (i, line) in lines.iter().enumerate() {
println!("Line {}: {}", i + 1, line);
}
assert_eq!(lines.len(), 2, "Expected 2 lines in the file");
// Part 2: Now create a KvIndexer and load the events from the file
let indexer_token = CancellationToken::new();
let kv_block_size = 32; // Default block size for testing
let indexer = KvIndexer::new(indexer_token.clone(), kv_block_size);
let indexer_event_tx = indexer.event_sender();
// Use the send_events method to load events from file to indexer
let count = KvRecorder::send_events(&file_path, &indexer_event_tx, false, None, None)
.await
.unwrap();
assert_eq!(count, 2, "Expected to send 2 events from file to indexer");
}
}
......@@ -29,6 +29,7 @@ pub mod model_card;
pub mod model_type;
pub mod preprocessor;
pub mod protocols;
pub mod recorder;
pub mod tokenizers;
pub mod tokens;
pub mod types;
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment