Unverified commit 6a728d10, authored by Yan Ru Pei and committed by GitHub
Browse files

chore: remove kv indexers bindings (#6159)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent adaf1a39
......@@ -1632,7 +1632,6 @@ dependencies = [
"ahash",
"aho-corasick",
"akin",
"aligned-vec",
"anyhow",
"async-nats",
"async-stream",
......@@ -1649,7 +1648,6 @@ dependencies = [
"bytes",
"candle-core",
"chrono",
"cudarc",
"dashmap 5.5.3",
"derive-getters",
"derive_builder",
......@@ -1684,8 +1682,6 @@ dependencies = [
"ndarray",
"ndarray-interp",
"ndarray-npy",
"nix 0.26.4",
"nixl-sys",
"object_store",
"offset-allocator",
"oneshot",
......@@ -3941,15 +3937,6 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
[[package]]
name = "memoffset"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4"
dependencies = [
"autocfg",
]
[[package]]
name = "memoffset"
version = "0.9.1"
......@@ -4243,19 +4230,6 @@ dependencies = [
"thiserror 1.0.69",
]
[[package]]
name = "nix"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b"
dependencies = [
"bitflags 1.3.2",
"cfg-if 1.0.4",
"libc",
"memoffset 0.7.1",
"pin-utils",
]
[[package]]
name = "nix"
version = "0.29.0"
......@@ -5502,7 +5476,7 @@ dependencies = [
"cfg-if 1.0.4",
"indoc",
"libc",
"memoffset 0.9.1",
"memoffset",
"once_cell",
"portable-atomic",
"pyo3-build-config",
......
......@@ -163,8 +163,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::preprocessor::MediaDecoder>()?;
m.add_class::<llm::preprocessor::MediaFetcher>()?;
m.add_class::<llm::kv::OverlapScores>()?;
m.add_class::<llm::kv::KvIndexer>()?;
m.add_class::<llm::kv::ApproxKvIndexer>()?;
m.add_class::<llm::kv::KvEventPublisher>()?;
m.add_class::<llm::kv::RadixTree>()?;
m.add_class::<llm::kv::ZmqKvEventListener>()?;
......
......@@ -10,7 +10,6 @@ use tokio_stream::StreamExt;
use super::*;
use crate::Component;
use llm_rs::kv_router::indexer::KvIndexerInterface;
use llm_rs::kv_router::protocols::compute_block_hash_for_seq;
use rs::pipeline::{AsyncEngine, SingleIn};
use tracing;
......@@ -678,186 +677,6 @@ impl Drop for RadixTree {
}
}
/// Python binding for the event-driven KV indexer.
///
/// Holds a shared handle to `llm_rs::kv_router::indexer::KvIndexer`. Constructing it
/// also starts the shared KV-router background task used for event consumption.
#[pyclass]
pub(crate) struct KvIndexer {
    inner: Arc<llm_rs::kv_router::indexer::KvIndexer>,
}

#[pymethods]
impl KvIndexer {
    /// Create an indexer bound to `component`.
    ///
    /// * `kv_block_size` - KV block size; must fit in a `u32`.
    /// * `consumer_uuid` - identifier passed to the background subscriber; a random
    ///   UUID v4 string is generated when omitted.
    ///
    /// Returns an error if `kv_block_size` does not fit in `u32` or if starting the
    /// background subscriber fails.
    #[new]
    #[pyo3(signature = (component, kv_block_size, consumer_uuid=None))]
    fn new(
        component: Component,
        kv_block_size: usize,
        consumer_uuid: Option<String>,
    ) -> PyResult<Self> {
        // A plain `as u32` would silently wrap an oversized block size into a
        // different (valid-looking) value; reject it up front instead.
        let block_size = u32::try_from(kv_block_size).map_err(to_pyerr)?;

        let runtime = pyo3_async_runtimes::tokio::get_runtime();
        runtime.block_on(async {
            let cancellation_token = component.inner.drt().runtime().child_token();
            let kv_indexer_metrics =
                llm_rs::kv_router::indexer::KvIndexerMetrics::from_component(&component.inner);

            let inner: Arc<llm_rs::kv_router::indexer::KvIndexer> =
                llm_rs::kv_router::indexer::KvIndexer::new(
                    cancellation_token.clone(),
                    block_size,
                    kv_indexer_metrics,
                )
                .into();

            // Use the shared start_kv_router_background function for event consumption
            // Pass None for snapshot_tx and get_workers_tx to skip snapshot handling in Python bindings
            // NOTE(review): the meaning of the trailing `None, true` arguments is not
            // visible from this file - confirm against `start_kv_router_background`.
            llm_rs::kv_router::subscriber::start_kv_router_background(
                component.inner.clone(),
                consumer_uuid.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
                inner.event_sender(),
                inner.remove_worker_sender(),
                None,
                None,
                cancellation_token,
                None,
                true,
            )
            .await
            .map_err(to_pyerr)?;

            Ok(Self { inner })
        })
    }

    /// Return the block size reported by the wrapped indexer.
    fn block_size(&self) -> usize {
        self.inner.block_size() as usize
    }

    /// Look up overlap scores for a sequence of precomputed local block hashes.
    ///
    /// Each `u64` in `sequence` is wrapped as a `LocalBlockHash`. Returns a Python
    /// awaitable that resolves to `OverlapScores`.
    fn find_matches<'p>(&self, py: Python<'p>, sequence: Vec<u64>) -> PyResult<Bound<'p, PyAny>> {
        let indexer = self.inner.clone();
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let local_block_hashes: Vec<llm_rs::kv_router::protocols::LocalBlockHash> = sequence
                .into_iter()
                .map(llm_rs::kv_router::protocols::LocalBlockHash)
                .collect();

            let rs_overlap_scores = indexer
                .find_matches(local_block_hashes)
                .await
                .map_err(to_pyerr)?;

            Ok(OverlapScores {
                inner: rs_overlap_scores,
            })
        })
    }

    /// Look up overlap scores for raw token ids.
    ///
    /// `_lora_id` is accepted for signature compatibility but is not used here.
    /// Returns a Python awaitable that resolves to `OverlapScores`.
    fn find_matches_for_request<'p>(
        &self,
        py: Python<'p>,
        token_ids: Vec<u32>,
        _lora_id: u64,
    ) -> PyResult<Bound<'p, PyAny>> {
        let indexer = self.inner.clone();
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let rs_overlap_scores = indexer
                .find_matches_for_request(token_ids.as_slice())
                .await
                .map_err(to_pyerr)?;

            Ok(OverlapScores {
                inner: rs_overlap_scores,
            })
        })
    }
}
/// Bindings for the approximate KV indexer. This is a wrapper around KvIndexer
/// that uses TTL-based expiration and pruning instead of receiving KV events from workers.
#[pyclass]
pub(crate) struct ApproxKvIndexer {
inner: Arc<llm_rs::kv_router::indexer::KvIndexer>,
}
#[pymethods]
impl ApproxKvIndexer {
#[new]
#[pyo3(signature = (component, kv_block_size, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8))]
fn new(
component: Component,
kv_block_size: usize,
router_ttl_secs: f64,
router_max_tree_size: usize,
router_prune_target_ratio: f64,
) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let cancellation_token = component.inner.drt().runtime().child_token();
let kv_indexer_metrics =
llm_rs::kv_router::indexer::KvIndexerMetrics::from_component(&component.inner);
// Build PruneConfig with the provided parameters
let prune_config = llm_rs::kv_router::approx::PruneConfig {
ttl: std::time::Duration::from_secs_f64(router_ttl_secs),
max_tree_size: router_max_tree_size,
prune_target_ratio: router_prune_target_ratio,
};
// Create KvIndexer with pruning enabled, but DO NOT subscribe to events
let inner: Arc<llm_rs::kv_router::indexer::KvIndexer> =
llm_rs::kv_router::indexer::KvIndexer::new_with_frequency(
cancellation_token.clone(),
None, // expiration_duration - not used with prune_config
kv_block_size as u32,
kv_indexer_metrics,
Some(prune_config),
)
.into();
// Note: We deliberately do NOT call start_kv_router_background here
// because ApproxKvIndexer doesn't use KV events from workers
Ok(Self { inner })
})
}
fn block_size(&self) -> usize {
self.inner.block_size() as usize
}
fn find_matches_for_request<'p>(
&self,
py: Python<'p>,
token_ids: Vec<u32>,
) -> PyResult<Bound<'p, PyAny>> {
let indexer = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let rs_overlap_scores = indexer
.find_matches_for_request(token_ids.as_slice())
.await
.map_err(to_pyerr)?;
Ok(OverlapScores {
inner: rs_overlap_scores,
})
})
}
#[pyo3(signature = (tokens, worker_id, dp_rank=0))]
fn process_routing_decision_for_request<'p>(
&self,
py: Python<'p>,
tokens: Vec<u32>,
worker_id: WorkerId,
dp_rank: DpRank,
) -> PyResult<Bound<'p, PyAny>> {
let indexer = self.inner.clone();
let block_size = self.inner.block_size();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let worker = llm_rs::kv_router::protocols::WorkerWithDpRank::new(worker_id, dp_rank);
let mut tokens_with_hashes = TokensWithHashes::new(tokens, block_size);
indexer
.process_routing_decision_for_request(&mut tokens_with_hashes, worker)
.await
.map_err(to_pyerr)?;
Ok(())
})
}
}
/// Helper function to create a KV router from an endpoint using the ModelManager
/// to ensure proper etcd registration.
/// Infers worker type using endpoint naming and router config:
......
......@@ -5,14 +5,12 @@
import logging
from dynamo._core import ApproxKvIndexer as ApproxKvIndexer
from dynamo._core import EngineType
from dynamo._core import EntrypointArgs as EntrypointArgs
from dynamo._core import HttpAsyncEngine as HttpAsyncEngine
from dynamo._core import HttpService as HttpService
from dynamo._core import KserveGrpcService as KserveGrpcService
from dynamo._core import KvEventPublisher as KvEventPublisher
from dynamo._core import KvIndexer as KvIndexer
from dynamo._core import KvPushRouter as KvPushRouter
from dynamo._core import KvRouterConfig as KvRouterConfig
from dynamo._core import LoRADownloader as LoRADownloader
......
......@@ -14,28 +14,16 @@
# limitations under the License.
import asyncio
import json
import threading
from typing import List
import pytest
from dynamo.llm import ApproxKvIndexer, KvEventPublisher, KvIndexer, RadixTree
from dynamo.runtime import Component, DistributedRuntime
from dynamo.llm import RadixTree
pytestmark = pytest.mark.pre_merge
@pytest.fixture
async def distributed_runtime():
    """Provide a DistributedRuntime for one test and shut it down afterwards."""
    runtime = DistributedRuntime(asyncio.get_running_loop(), "etcd", "nats")
    yield runtime
    # Teardown: runs once the test using this fixture has finished.
    runtime.shutdown()
@pytest.mark.timeout(5) # Expected: ~1s, timeout set to 5x for safety
def test_radix_tree_binding():
"""Test RadixTree binding directly with store event and find matches"""
......@@ -201,109 +189,3 @@ def test_radix_tree_thread_safety(
assert (
len(blocks_after_removal) == expected_blocks_after_removal
), f"Expected {expected_blocks_after_removal} block events after removal, got {len(blocks_after_removal)}"
@pytest.mark.asyncio
@pytest.mark.timeout(5)  # Expected: ~1s, timeout set to 5x for safety
async def test_event_handler(distributed_runtime):
    """Publish a stored/removed KV event and check that KvIndexer reflects each one."""
    block_size = 32
    listener = distributed_runtime.namespace("kv_test").component("event")

    # Publisher side. KvEventPublisher ignores the worker_id it is given and uses the
    # component's connection_id, and Component does not expose that id directly, so
    # read it from a throwaway endpoint.
    worker_id = listener.endpoint("dummy").connection_id()
    event_publisher = EventPublisher(listener, worker_id, block_size)

    # Indexer side.
    indexer = KvIndexer(listener, block_size)

    tokens = [3] * block_size
    lora_id = 0  # lora_id is not used in the indexer

    # Nothing has been published yet, so there must be no overlap.
    scores = await indexer.find_matches_for_request(tokens, lora_id)
    assert not scores.scores

    event_publisher.store_event(tokens, lora_id)
    # Events are sent asynchronously; give the indexer time to process them.
    await asyncio.sleep(0.2)

    scores = await indexer.find_matches_for_request(tokens, lora_id)
    worker_key = (worker_id, 0)  # (worker_id, dp_rank)
    assert scores.scores, "No scores found"
    assert worker_key in scores.scores, f"Worker {worker_key} not found in scores"
    assert (
        scores.scores[worker_key] == 1
    ), f"Expected score 1, got {scores.scores[worker_key]}"

    # After the block is removed the overlap must disappear again.
    event_publisher.remove_event()
    await asyncio.sleep(0.2)

    scores = await indexer.find_matches_for_request(tokens, lora_id)
    assert not scores.scores, f"Scores still present: {scores.scores}"
@pytest.mark.asyncio
@pytest.mark.timeout(5)  # Expected: ~1s, timeout set to 5x for safety
async def test_approx_kv_indexer(distributed_runtime):
    """Test ApproxKvIndexer with TTL-based block tracking"""
    block_size = 32
    listener = distributed_runtime.namespace("kv_test").component("approx_kv")

    # Constructed with the default TTL (120s).
    indexer = ApproxKvIndexer(listener, block_size)

    # Two full blocks worth of tokens.
    tokens = [0] * (block_size * 2)

    # Before any routing decision is recorded there is nothing to match.
    initial = await indexer.find_matches_for_request(tokens)
    assert not initial.scores

    worker_id = 0
    # Recording a routing decision should add the blocks to the indexer.
    await indexer.process_routing_decision_for_request(tokens, worker_id)

    matched = await indexer.find_matches_for_request(tokens)
    assert matched.scores

    worker_key = (worker_id, 0)  # (worker_id, dp_rank)
    assert worker_key in matched.scores
    assert matched.scores[worker_key] == 2  # 2 blocks (tokens is 2 blocks long)
class EventPublisher:
    """Test helper that publishes single-block stored/removed events via KvEventPublisher."""

    def __init__(self, component: Component, worker_id: int, kv_block_size: int):
        self.publisher = KvEventPublisher(component, worker_id, kv_block_size)
        # Source of unique block hashes (event_id is managed internally by the publisher).
        self.block_hash_counter = 0
        # Hashes published so far, oldest first.
        self.block_hashes: List[int] = []

    def store_event(self, tokens, lora_id):
        """Publish `tokens` as one stored block chained to the previously published block."""
        # The parent is the last block already published, never the one being stored now.
        if self.block_hashes:
            parent_hash = self.block_hashes[-1]
        else:
            parent_hash = None

        num_block_tokens = [len(tokens)]
        block_hashes = [self.block_hash_counter]
        self.publisher.publish_stored(
            tokens,
            num_block_tokens,
            block_hashes,
            lora_id,
            parent_hash,
        )

        self.block_hashes.append(self.block_hash_counter)
        self.block_hash_counter += 1

    def remove_event(self):
        """Publish a removal for the most recently stored block hash."""
        latest_hash = self.block_hashes[-1]
        self.publisher.publish_removed([latest_hash])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment