Unverified Commit 681951d4 authored by Michael Feil's avatar Michael Feil Committed by GitHub
Browse files

feat: python gil release for radix tree + dump_tree_as_events in python (#3748)


Signed-off-by: default avatarmichaelfeil <me@michaelfeil.eu>
Co-authored-by: default avatarYan Ru Pei <yanrpei@gmail.com>
parent e6fc0e29
...@@ -3,7 +3,9 @@ ...@@ -3,7 +3,9 @@
use pythonize::{depythonize, pythonize}; use pythonize::{depythonize, pythonize};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::AtomicU32; use std::sync::atomic::AtomicU32;
use std::sync::mpsc;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use super::*; use super::*;
...@@ -335,10 +337,36 @@ impl OverlapScores { ...@@ -335,10 +337,36 @@ impl OverlapScores {
} }
} }
// NOTE: the user needs to guarantee that this stays single threaded in Python land #[derive(Debug)]
#[pyclass(unsendable)] enum RadixTreeRequest {
// Commands sent from the Python-facing `RadixTree` methods to the thread that owns the
// tree. Every variant except `Shutdown` carries a `response_tx` on which the worker
// thread sends its reply.
/// Find prefix matches for `local_block_hashes`; the reply is the resulting overlap scores.
FindMatches {
local_block_hashes: Vec<llm_rs::kv_router::protocols::LocalBlockHash>,
early_exit: bool,
response_tx: mpsc::SyncSender<llm_rs::kv_router::indexer::OverlapScores>,
},
/// Apply one KV cache event for `worker_id`. The event travels as raw JSON bytes and is
/// deserialized on the worker thread, so the reply carries any resulting error.
ApplyEvent {
worker_id: WorkerId,
kv_cache_event_bytes: Vec<u8>,
response_tx: mpsc::SyncSender<PyResult<()>>,
},
/// Call `remove_worker(worker_id)` on the tree; the reply is a bare acknowledgement.
RemoveWorker {
worker_id: WorkerId,
response_tx: mpsc::SyncSender<()>,
},
/// Call `clear_all_blocks(worker_id)` on the tree; the reply is a bare acknowledgement.
ClearAllBlocks {
worker_id: WorkerId,
response_tx: mpsc::SyncSender<()>,
},
/// Dump the tree as a list of router events.
DumpTreeAsEvents {
response_tx: mpsc::SyncSender<Vec<llm_rs::kv_router::indexer::RouterEvent>>,
},
/// Ask the worker thread to leave its receive loop. No reply is sent.
Shutdown,
}
// NOTE: RadixTree no longer holds the tree itself. The tree is created and mutated on a
// dedicated thread spawned in `new`; every method forwards a request to that thread over
// a channel and waits for the reply.
#[pyclass]
pub(crate) struct RadixTree { pub(crate) struct RadixTree {
// Sending half of the request channel consumed by the worker thread.
inner: llm_rs::kv_router::indexer::RadixTree, request_tx: mpsc::Sender<RadixTreeRequest>,
} }
#[pymethods] #[pymethods]
...@@ -347,55 +375,249 @@ impl RadixTree { ...@@ -347,55 +375,249 @@ impl RadixTree {
/// Creates the handle and spawns the thread that owns the underlying radix tree.
///
/// `expiration_duration_secs` is converted to a `Duration` and passed to
/// `RadixTree::new_with_frequency`; `None` is passed through unchanged.
///
/// NOTE(review): `Duration::from_secs_f64` panics on negative or non-finite input, so
/// such a value coming from Python would panic here instead of raising a `ValueError` —
/// confirm whether input validation is wanted.
#[pyo3(signature = (expiration_duration_secs=None))] #[pyo3(signature = (expiration_duration_secs=None))]
fn new(expiration_duration_secs: Option<f64>) -> PyResult<Self> { fn new(expiration_duration_secs: Option<f64>) -> PyResult<Self> {
let expiration_duration = expiration_duration_secs.map(std::time::Duration::from_secs_f64); let expiration_duration = expiration_duration_secs.map(std::time::Duration::from_secs_f64);
let inner = llm_rs::kv_router::indexer::RadixTree::new_with_frequency(expiration_duration);
Ok(Self { inner }) let (request_tx, request_rx) = mpsc::channel::<RadixTreeRequest>();
// Spawn the worker thread. The tree is constructed inside the thread and never
// leaves it; requests are handled one at a time in the order they are received.
std::thread::spawn(move || {
let mut radix_tree =
llm_rs::kv_router::indexer::RadixTree::new_with_frequency(expiration_duration);
loop {
match request_rx.recv() {
// Explicit shutdown request (sent from `Drop`): stop the loop.
Ok(RadixTreeRequest::Shutdown) => {
tracing::debug!("RadixTree thread received shutdown request");
break;
}
Ok(request) => {
Self::handle_request(&mut radix_tree, request);
}
// `recv` fails once every sender has been dropped: nothing more can arrive.
Err(mpsc::RecvError) => {
tracing::debug!("RadixTree request channel disconnected");
break;
}
}
}
});
Ok(Self { request_tx })
} }
/// Finds prefix matches for `sequence` and returns the overlap scores.
///
/// The hashes are wrapped as `LocalBlockHash`, sent to the worker thread as a
/// `FindMatches` request, and the reply is awaited with the GIL released.
/// Returns a `RuntimeError` to Python if the request cannot be sent or no reply arrives.
#[pyo3(signature = (sequence, early_exit=false))] #[pyo3(signature = (sequence, early_exit=false))]
fn find_matches( fn find_matches(
&self, &self,
_py: Python, py: Python,
sequence: Vec<u64>, sequence: Vec<u64>,
early_exit: bool, early_exit: bool,
) -> PyResult<OverlapScores> { ) -> PyResult<OverlapScores> {
// One-slot reply channel dedicated to this call.
let local_block_hashes: Vec<llm_rs::kv_router::protocols::LocalBlockHash> = sequence let (response_tx, response_rx) = mpsc::sync_channel(1);
.into_iter()
.map(llm_rs::kv_router::protocols::LocalBlockHash)
.collect();
// The u64 -> LocalBlockHash conversion also runs with the GIL released.
let rs_overlap_scores = self.inner.find_matches(local_block_hashes, early_exit); let local_block_hashes = py.allow_threads(|| {
Ok(OverlapScores { sequence
inner: rs_overlap_scores, .into_iter()
}) .map(llm_rs::kv_router::protocols::LocalBlockHash)
.collect()
});
let request = RadixTreeRequest::FindMatches {
local_block_hashes,
early_exit,
response_tx,
};
// Sending fails only when the worker thread's receiver has been dropped.
self.request_tx.send(request).map_err(|_| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"RadixTree background task has shut down",
)
})?;
// Release the GIL while blocked on the reply so other Python threads can run.
let result = py.allow_threads(move || {
response_rx.recv().map_err(|_| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
})
})?;
Ok(OverlapScores { inner: result })
} }
/// Applies one JSON-encoded KV cache event for `worker_id`.
///
/// The bytes are copied and sent to the worker thread as an `ApplyEvent` request;
/// deserialization and the tree update both happen on that thread. The worker's
/// `PyResult` is returned as-is, so its errors reach Python; a `RuntimeError` is
/// returned if the request cannot be sent or no reply arrives.
fn apply_event( fn apply_event(
&mut self, &self,
_py: Python, py: Python,
worker_id: WorkerId, worker_id: WorkerId,
kv_cache_event_bytes: &[u8], kv_cache_event_bytes: &[u8],
) -> PyResult<()> { ) -> PyResult<()> {
// One-slot reply channel dedicated to this call.
let kv_cache_event: llm_rs::kv_router::protocols::KvCacheEvent = let (response_tx, response_rx) = mpsc::sync_channel(1);
serde_json::from_slice(kv_cache_event_bytes).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Failed to deserialize KvCacheEvent: {}",
e
))
})?;
let router_event = llm_rs::kv_router::indexer::RouterEvent::new(worker_id, kv_cache_event); let request = RadixTreeRequest::ApplyEvent {
let _ = self.inner.apply_event(router_event); worker_id,
Ok(()) kv_cache_event_bytes: kv_cache_event_bytes.to_vec(),
response_tx,
};
// Sending fails only when the worker thread's receiver has been dropped.
self.request_tx.send(request).map_err(|_| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"RadixTree background task has shut down",
)
})?;
// Release the GIL while blocked on the reply so other Python threads can run.
let result = py.allow_threads(move || response_rx.recv());
// Outer error: the reply never arrived. The trailing `?` unwraps that layer and the
// inner `PyResult<()>` produced by the worker becomes this method's return value.
result.map_err(|_| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
})?
} }
/// Removes `worker_id` from the tree.
///
/// Sends a `RemoveWorker` request to the worker thread and waits for its
/// acknowledgement with the GIL released. Returns a `RuntimeError` to Python if the
/// request cannot be sent or no reply arrives.
fn remove_worker(&mut self, _py: Python, worker_id: WorkerId) -> PyResult<()> { fn remove_worker(&self, py: Python, worker_id: WorkerId) -> PyResult<()> {
self.inner.remove_worker(worker_id); let (response_tx, response_rx) = mpsc::sync_channel(1);
Ok(())
let request = RadixTreeRequest::RemoveWorker {
worker_id,
response_tx,
};
// Sending fails only when the worker thread's receiver has been dropped.
self.request_tx.send(request).map_err(|_| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"RadixTree background task has shut down",
)
})?;
// Release the GIL while blocked on the reply so other Python threads can run.
py.allow_threads(move || {
response_rx.recv().map_err(|_| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
})
})
} }
/// Clears all blocks recorded for `worker_id`.
///
/// Sends a `ClearAllBlocks` request to the worker thread and waits for its
/// acknowledgement with the GIL released. Returns a `RuntimeError` to Python if the
/// request cannot be sent or no reply arrives.
fn clear_all_blocks(&mut self, _py: Python, worker_id: WorkerId) -> PyResult<()> { fn clear_all_blocks(&self, py: Python, worker_id: WorkerId) -> PyResult<()> {
self.inner.clear_all_blocks(worker_id); let (response_tx, response_rx) = mpsc::sync_channel(1);
Ok(())
let request = RadixTreeRequest::ClearAllBlocks {
worker_id,
response_tx,
};
// Sending fails only when the worker thread's receiver has been dropped.
self.request_tx.send(request).map_err(|_| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"RadixTree background task has shut down",
)
})?;
// Release the GIL while blocked on the reply so other Python threads can run.
py.allow_threads(move || {
response_rx.recv().map_err(|_| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
})
})
}
/// Dumps the tree as a list of JSON strings, one per router event.
///
/// A `DumpTreeAsEvents` request is sent to the worker thread; waiting for the reply and
/// serializing the returned events both happen with the GIL released.
///
/// Returns a `RuntimeError` to Python if the request cannot be sent or no reply arrives,
/// and a `ValueError` if an event cannot be serialized.
fn dump_tree_as_events(&self, py: Python) -> PyResult<Vec<String>> {
    let (reply_tx, reply_rx) = mpsc::sync_channel(1);

    let sent = self.request_tx.send(RadixTreeRequest::DumpTreeAsEvents {
        response_tx: reply_tx,
    });
    if sent.is_err() {
        return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
            "Failed to send dump tree request",
        ));
    }

    // Neither the blocking receive nor the JSON encoding touches Python objects, so
    // both run without the GIL.
    py.allow_threads(move || -> PyResult<Vec<String>> {
        let events = match reply_rx.recv() {
            Ok(events) => events,
            Err(_) => {
                return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                    "Failed to receive dump tree response",
                ));
            }
        };

        let mut serialized = Vec::with_capacity(events.len());
        for event in events {
            match serde_json::to_string(&event) {
                Ok(json) => serialized.push(json),
                Err(e) => {
                    return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                        "Failed to serialize event to JSON: {}",
                        e
                    )));
                }
            }
        }
        Ok(serialized)
    })
}
}
impl RadixTree {
    /// Runs one request against the tree on the worker thread and replies to the caller.
    ///
    /// `radix_tree` is the tree owned by the receive loop and `request` is the message
    /// just taken off the channel. The outcome of each operation is sent on the
    /// `response_tx` bundled with the request. A failed reply send is deliberately
    /// ignored: the result of `send` is discarded for every variant.
    fn handle_request(
        radix_tree: &mut llm_rs::kv_router::indexer::RadixTree,
        request: RadixTreeRequest,
    ) {
        match request {
            RadixTreeRequest::Shutdown => {
                // This is handled in the main loop
            }
            RadixTreeRequest::DumpTreeAsEvents { response_tx } => {
                let _ = response_tx.send(radix_tree.dump_tree_as_events());
            }
            RadixTreeRequest::ClearAllBlocks {
                worker_id,
                response_tx,
            } => {
                radix_tree.clear_all_blocks(worker_id);
                let _ = response_tx.send(());
            }
            RadixTreeRequest::RemoveWorker {
                worker_id,
                response_tx,
            } => {
                radix_tree.remove_worker(worker_id);
                let _ = response_tx.send(());
            }
            RadixTreeRequest::FindMatches {
                local_block_hashes,
                early_exit,
                response_tx,
            } => {
                let scores = radix_tree.find_matches(local_block_hashes, early_exit);
                let _ = response_tx.send(scores);
            }
            RadixTreeRequest::ApplyEvent {
                worker_id,
                kv_cache_event_bytes,
                response_tx,
            } => {
                // Step 1: decode the JSON payload; a decode failure becomes a ValueError.
                let parsed = serde_json::from_slice::<llm_rs::kv_router::protocols::KvCacheEvent>(
                    &kv_cache_event_bytes,
                );
                // Step 2: apply the decoded event; a tree failure becomes a RuntimeError.
                let outcome = parsed
                    .map_err(|e| {
                        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                            "Failed to deserialize KvCacheEvent: {}",
                            e
                        ))
                    })
                    .and_then(|kv_cache_event| {
                        radix_tree
                            .apply_event(llm_rs::kv_router::indexer::RouterEvent::new(
                                worker_id,
                                kv_cache_event,
                            ))
                            .map(|_| ())
                            .map_err(|e| {
                                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                                    "Failed to apply event: {}",
                                    e
                                ))
                            })
                    });
                let _ = response_tx.send(outcome);
            }
        }
    }
}
// Cleanup when RadixTree is dropped: ask the worker thread to leave its receive loop.
impl Drop for RadixTree {
fn drop(&mut self) {
// Best-effort shutdown. A send error means the worker's receiver is already gone,
// so there is nothing left to stop and the result is ignored.
let _ = self.request_tx.send(RadixTreeRequest::Shutdown);
} }
} }
......
...@@ -553,7 +553,8 @@ class RadixTree: ...@@ -553,7 +553,8 @@ class RadixTree:
""" """
A RadixTree that tracks KV cache blocks and can find prefix matches for sequences. A RadixTree that tracks KV cache blocks and can find prefix matches for sequences.
NOTE: This class is not thread-safe and should only be used from a single thread in Python. Thread-safe: operations route to a dedicated background thread and long calls
release the Python GIL.
""" """
def __init__(self, expiration_duration_secs: Optional[float] = None) -> None: def __init__(self, expiration_duration_secs: Optional[float] = None) -> None:
...@@ -612,6 +613,15 @@ class RadixTree: ...@@ -612,6 +613,15 @@ class RadixTree:
""" """
... ...
def dump_tree_as_events(self) -> List[str]:
    """
    Dump the current RadixTree state as a list of JSON strings, one per event.

    Each string is the JSON serialization of one router event produced by the
    underlying tree and can be parsed with ``json.loads``. The list is empty when
    the tree holds no blocks.

    Returns:
        List of JSON-serialized events as strings

    Raises:
        RuntimeError: if the background thread is unavailable or sends no reply
        ValueError: if an event cannot be serialized to JSON
    """
    ...
class KvIndexer: class KvIndexer:
""" """
A KV Indexer that tracks KV Events emitted by workers. Events include add_block and remove_block. A KV Indexer that tracks KV Events emitted by workers. Events include add_block and remove_block.
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
import asyncio import asyncio
import json
import threading
from typing import List from typing import List
import pytest import pytest
...@@ -88,11 +90,125 @@ async def test_radix_tree_binding(distributed_runtime): ...@@ -88,11 +90,125 @@ async def test_radix_tree_binding(distributed_runtime):
overlap_scores.scores[worker_key] == 1 overlap_scores.scores[worker_key] == 1
), f"Expected score 1 for worker {worker_key}, got {overlap_scores.scores[worker_key]}" ), f"Expected score 1 for worker {worker_key}, got {overlap_scores.scores[worker_key]}"
blocks = radix_tree.dump_tree_as_events()
assert len(blocks) == 1, f"Expected 1 block event, got {len(blocks)}"
json.loads(blocks[0]) # check valid json
# cleanup
radix_tree.remove_worker(worker_id)
blocks_empty = radix_tree.dump_tree_as_events()
assert (
len(blocks_empty) == 0
), f"Expected 0 block events after removal, got {len(blocks_empty)}"
print( print(
f"✓ RadixTree test passed: worker {worker_key} has score {overlap_scores.scores[worker_key]}" f"✓ RadixTree test passed: worker {worker_key} has score {overlap_scores.scores[worker_key]}"
) )
@pytest.mark.asyncio
@pytest.mark.forked
@pytest.mark.parametrize("num_threads", [2, 3, 5, 128])
@pytest.mark.parametrize("prepopulate_worker_ids", [True, False])
@pytest.mark.parametrize("expiration_duration_secs", [None])
@pytest.mark.parametrize("is_threaded", [True, False])
async def test_radix_tree_thread_safety(
    distributed_runtime,
    num_threads,
    prepopulate_worker_ids,
    expiration_duration_secs,
    is_threaded,
):
    """Test RadixTree thread safety by applying events from multiple threads.

    Each worker stores a single block for its own worker id, either from its own
    thread (``is_threaded``) or sequentially. With ``prepopulate_worker_ids`` every
    worker id first stores a block under a different hash, so the concurrent phase
    hits worker ids the tree already knows about.
    """
    radix_tree = RadixTree(expiration_duration_secs=expiration_duration_secs)
    threads = []
    done_counter = 0
    exception_counter = 0
    # `+=` on a shared closure variable is a read-modify-write and is not guaranteed
    # to be atomic across threads, so both counters are updated under this lock.
    counter_lock = threading.Lock()

    def worker(worker_id, prepopulate_worker_ids: bool = False):
        nonlocal done_counter, exception_counter
        try:
            block_hash = worker_id
            if prepopulate_worker_ids:
                block_hash = (
                    2**32 - worker_id
                )  # use different hash for prepopulate_worker_ids
            assert 0 <= block_hash < 2**64  # needs to be valid u64
            store_event = {
                "event_id": worker_id,
                "data": {
                    "stored": {
                        "parent_hash": None,
                        "blocks": [
                            {
                                "block_hash": block_hash,
                                "tokens_hash": block_hash,
                            }
                        ],
                    }
                },
            }
            event_bytes = json.dumps(store_event).encode("utf-8")
            radix_tree.apply_event(worker_id, event_bytes)
            # Only the main (non-warmup) phase counts towards `done_counter`.
            if not prepopulate_worker_ids:
                with counter_lock:
                    done_counter += 1
        except Exception as e:
            print(f"Exception in worker {worker_id}: {e}")
            with counter_lock:
                exception_counter += 1

    if prepopulate_worker_ids:
        for i in range(num_threads):
            worker(i, prepopulate_worker_ids=True)
        assert (
            exception_counter == 0
        ), f"Warmup: expected 0 exceptions, got {exception_counter}"

    for i in range(num_threads):
        if is_threaded:
            t = threading.Thread(target=worker, args=(i,))
            threads.append(t)
            t.start()
        else:
            worker(i)

    if is_threaded:
        timeout = 10  # seconds
        for t in threads:
            t.join(timeout)
            assert not t.is_alive(), "Thread timed out"

    assert exception_counter == 0, f"Expected 0 exceptions, got {exception_counter}"
    assert (
        done_counter == num_threads
    ), f"Expected {num_threads} done, got {done_counter}"

    for i in range(num_threads):
        overlap_scores = radix_tree.find_matches([i])
        assert overlap_scores.scores is not None
        worker_key = (i, 0)
        assert (
            worker_key in overlap_scores.scores
        ), f"Worker {worker_key} not found in scores"
        assert (
            overlap_scores.scores[worker_key] == 1
        ), f"Expected score 1 for worker {worker_key}, got {overlap_scores.scores[worker_key]}"

    # get all blocks
    blocks = radix_tree.dump_tree_as_events()
    expected_blocks = num_threads + (prepopulate_worker_ids * num_threads)
    assert (
        len(blocks) == expected_blocks
    ), f"Expected {expected_blocks} block events, got {len(blocks)}"

    # remove single worker
    radix_tree.remove_worker(0)
    # Worker 0 owns one block, plus one more if the warmup phase ran.
    expected_blocks_after_removal = expected_blocks - (
        2 if prepopulate_worker_ids else 1
    )
    blocks_after_removal = radix_tree.dump_tree_as_events()
    assert (
        len(blocks_after_removal) == expected_blocks_after_removal
    ), f"Expected {expected_blocks_after_removal} block events after removal, got {len(blocks_after_removal)}"
# TODO Figure out how to test with different kv_block_size # TODO Figure out how to test with different kv_block_size
# Right now I get an error in EventPublisher init when I run this test # Right now I get an error in EventPublisher init when I run this test
# back to back. It occurs when calling dynamo_llm_init and I think is related to the # back to back. It occurs when calling dynamo_llm_init and I think is related to the
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment