"examples/vscode:/vscode.git/clone" did not exist on "fa6a7f94d59d8c38287f690acd28766e6c772619"
Unverified Commit c61e0dd3 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: merge KvIndexer and ApproxKvIndexer (#4500)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 989e246e
......@@ -131,7 +131,25 @@ def parse_args():
action="store_false",
dest="use_kv_events",
default=os.environ.get("DYN_KV_EVENTS", "true").lower() != "false",
help="KV Router: Disable KV events. When set, uses ApproxKvRouter for predicting block creation/deletion based only on incoming requests at a timer. By default, KV events are enabled.",
help="KV Router: Disable KV events. When set, the router predicts cache state based on routing decisions with TTL-based expiration and pruning, rather than receiving events from workers. By default, KV events are enabled.",
)
parser.add_argument(
"--router-ttl",
type=float,
default=float(os.environ.get("DYN_ROUTER_TTL", "120.0")),
help="KV Router: Time-to-live in seconds for blocks when KV events are disabled. Only used when --no-kv-events is set. Can be set via DYN_ROUTER_TTL env var (default: 120.0).",
)
parser.add_argument(
"--router-max-tree-size",
type=int,
default=int(os.environ.get("DYN_ROUTER_MAX_TREE_SIZE", str(2**10))),
help="KV Router: Maximum tree size before pruning when KV events are disabled. Only used when --no-kv-events is set. Can be set via DYN_ROUTER_MAX_TREE_SIZE env var (default: 1024).",
)
parser.add_argument(
"--router-prune-target-ratio",
type=float,
default=float(os.environ.get("DYN_ROUTER_PRUNE_TARGET_RATIO", "0.8")),
help="KV Router: Target size ratio after pruning when KV events are disabled. Only used when --no-kv-events is set. Can be set via DYN_ROUTER_PRUNE_TARGET_RATIO env var (default: 0.8).",
)
parser.add_argument(
"--namespace",
......@@ -282,6 +300,9 @@ async def async_main():
router_snapshot_threshold=flags.router_snapshot_threshold,
router_reset_states=flags.router_reset_states,
router_track_active_blocks=flags.router_track_active_blocks,
router_ttl_secs=flags.router_ttl,
router_max_tree_size=flags.router_max_tree_size,
router_prune_target_ratio=flags.router_prune_target_ratio,
)
elif flags.router_mode == "random":
router_mode = RouterMode.Random
......
......@@ -185,7 +185,7 @@ def parse_args():
action="store_false",
dest="use_kv_events",
default=True,
help="KV Router: Disable KV events. When set, uses ApproxKvRouter for predicting block creation/deletion based only on incoming requests. By default, KV events are enabled.",
help="KV Router: Disable KV events. When set, the router predicts cache state based on routing decisions with TTL-based expiration and pruning, rather than receiving events from workers. By default, KV events are enabled.",
)
parser.add_argument(
......@@ -218,6 +218,27 @@ def parse_args():
help="KV Router: Disable tracking of active blocks (blocks being used for ongoing generation). By default, active blocks are tracked for load balancing (default: True)",
)
parser.add_argument(
"--router-ttl-secs",
type=float,
default=120.0,
help="KV Router: TTL for blocks in seconds. Only used when --no-kv-events is set. Controls how long cached blocks are considered valid without explicit events (default: 120.0)",
)
parser.add_argument(
"--router-max-tree-size",
type=int,
default=2**10,
help="KV Router: Maximum tree size before pruning. Only used when --no-kv-events is set. When the indexer tree exceeds this size, pruning is triggered (default: 1024)",
)
parser.add_argument(
"--router-prune-target-ratio",
type=float,
default=0.8,
help="KV Router: Target size ratio after pruning (0.0-1.0). Only used when --no-kv-events is set. Determines how aggressively to prune the tree (default: 0.8)",
)
return parser.parse_args()
......@@ -244,7 +265,10 @@ async def worker(runtime: DistributedRuntime):
f"use_kv_events={args.use_kv_events}, "
f"router_replica_sync={args.router_replica_sync}, "
f"router_reset_states={args.router_reset_states}, "
f"router_track_active_blocks={args.router_track_active_blocks}"
f"router_track_active_blocks={args.router_track_active_blocks}, "
f"router_ttl_secs={args.router_ttl_secs}, "
f"router_max_tree_size={args.router_max_tree_size}, "
f"router_prune_target_ratio={args.router_prune_target_ratio}"
)
# Create KvRouter configuration
......@@ -256,6 +280,9 @@ async def worker(runtime: DistributedRuntime):
router_snapshot_threshold=args.router_snapshot_threshold,
router_reset_states=args.router_reset_states,
router_track_active_blocks=args.router_track_active_blocks,
router_ttl_secs=args.router_ttl_secs,
router_max_tree_size=args.router_max_tree_size,
router_prune_target_ratio=args.router_prune_target_ratio,
)
# Create service component - use "router" as component name
......
......@@ -148,7 +148,7 @@ The KV-aware routing arguments:
- `--router-temperature`: Sets the temperature when randomly selecting workers to route to via softmax sampling on the router cost logits. Setting it to 0 recovers the deterministic behavior where the min logit is picked.
- `--use-kv-events`: Sets whether to listen to KV events for maintaining the global view of cached blocks. If true, then we use the `KvIndexer` to listen to the block creation and deletion events. If false, `ApproxKvIndexer`, which assumes the kv cache of historical prompts exists for fixed time durations (hard-coded to 120s), is used to predict the kv cache hit ratio in each engine. Set false if your backend engine does not emit KV events.
- `--use-kv-events`: Sets whether to listen to KV events for maintaining the global view of cached blocks. If true, the router uses KV events to track block creation and deletion from workers. If false, the router predicts cache state based on routing decisions with TTL-based expiration (default 120s) and pruning. Set false if your backend engine does not emit KV events.
### Request Migration
......
......@@ -210,8 +210,8 @@ The router uses KV events from workers by default to maintain an accurate global
- Recommended for production deployments
- **Without KV Events (--no-kv-events)**:
- Uses ApproxKvIndexer to estimate cached blocks from routing decisions
- Assumes blocks from recent requests remain cached
- Router predicts cache state based on routing decisions with TTL-based expiration and pruning
- Tracks blocks from recent requests with configurable time-to-live
- Reduces overhead at the cost of routing accuracy
- Suitable for testing or when event processing becomes a bottleneck
......
......@@ -21,7 +21,7 @@ The main KV-aware routing arguments:
- `--router-temperature`: Controls worker selection randomness through softmax sampling of router cost logits. A value of 0 (default) ensures deterministic selection of the lowest-cost worker, while higher values introduce more randomness.
- `--no-kv-events`: Disables KV event tracking. By default (when this flag is not provided), the router uses `KvIndexer` to monitor block creation and deletion events. When disabled with this flag, uses `ApproxKvIndexer`, which estimates cache hits based on a fixed time window (120s). Use this flag if your backend doesn't support KV events (or you are not confident in the accuracy or responsiveness of the events).
- `--no-kv-events`: Disables KV event tracking. By default (when this flag is not provided), the router uses KV events to monitor block creation and deletion from workers. When disabled with this flag, the router predicts cache state based on routing decisions with TTL-based expiration (default 120s) and pruning. Use this flag if your backend doesn't support KV events (or you are not confident in the accuracy or responsiveness of the events).
- `--router-replica-sync`: Disabled by default. Enables NATS-based synchronization of local routing decisions between router replicas. When enabled, routers share their active sequence information and local predictions of block usage, improving routing consistency across instances. Note that this does not sync the radix tree or cached KV block states themselves - those are synchronized through JetStream events.
......@@ -33,10 +33,18 @@ The main KV-aware routing arguments:
- `--busy-threshold`: Threshold (0.0-1.0) for determining when a worker is considered busy based on KV cache usage. When a worker's KV cache active blocks exceed this percentage of total blocks, it will be marked as busy and excluded from routing. If not set, busy detection is disabled. This feature works with all routing modes (`--router-mode kv|round-robin|random`) as long as backend engines emit `ForwardPassMetrics`.
- `--router-ttl`: Time-to-live in seconds for blocks in the router's local cache predictions. Blocks older than this duration will be automatically expired and removed from the router's radix tree. Defaults to 120.0 seconds when `--no-kv-events` is used. This helps manage memory usage by removing stale cache predictions that are unlikely to be accurate.
- `--router-max-tree-size`: Maximum tree size (number of blocks) before pruning is triggered. When the total number of blocks in the radix tree exceeds this threshold, the router will prune the least recently used blocks. Defaults to 1024 (2^10 blocks) when `--no-kv-events` is used. This prevents unbounded memory growth in long-running deployments.
- `--router-prune-target-ratio`: Target size ratio to prune down to when `--router-max-tree-size` is exceeded. For example, with a value of 0.8 (default) and a max tree size of 1024, the router will prune down to approximately 819 blocks when the threshold is exceeded. Defaults to 0.8 when `--no-kv-events` is used. This creates headroom before the next pruning cycle.
>[!Note]
> State persistence is only available when KV events are enabled (default). When using `--no-kv-events` with `ApproxKvIndexer`, state persistence is not currently supported.
> State persistence is only available when KV events are enabled (default). When using `--no-kv-events`, state persistence is not currently supported.
>
> When `--kv-overlap-score-weight` is set to 0 or `--no-kv-events` is set, no KvIndexer will be launched to drain and process KV events. It's recommended to disable your backend workers from relaying events through `KvEventPublisher` to avoid event accumulation in JetStream. WIP to enable disabling publishing of KV events completely in these cases.
>
> The CLI arguments `--router-ttl`, `--router-max-tree-size`, and `--router-prune-target-ratio` control local cache management when the router operates without receiving events from workers. When KV events are enabled (default), the router relies on worker-side eviction events and these parameters are ignored.
## Prerequisites and Limitations
......
......@@ -69,8 +69,8 @@ pub struct Flags {
pub router_temperature: Option<f64>,
/// KV Router: Whether to use KV events to maintain the view of cached blocks
/// If false, would use ApproxKvRouter for predicting block creation / deletion
/// based only on incoming requests at a timer.
/// If false, the router predicts cache state based on routing decisions
/// with TTL-based expiration and pruning, rather than receiving events from workers.
/// Default: true
#[arg(long)]
pub use_kv_events: Option<bool>,
......@@ -189,6 +189,9 @@ impl Flags {
// defaulting below args (no longer maintaining new flags for dynamo-run)
None,
None,
None,
None,
None,
),
)
}
......
......@@ -452,9 +452,12 @@ pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline(
(router_temperature >= 0.0).then_some(router_temperature),
Some(use_kv_events),
Some(router_replica_sync),
None,
None,
None,
None, // track_active_blocks
None, // router_snapshot_threshold
None, // router_reset_states
None, // router_ttl_secs
None, // router_max_tree_size
None, // router_prune_target_ratio
))
} else {
None
......
......@@ -41,7 +41,8 @@ impl KvRouterConfig {
#[pymethods]
impl KvRouterConfig {
#[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_snapshot_threshold=1000000, router_reset_states=false))]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1024, router_prune_target_ratio=0.8))]
#[allow(clippy::too_many_arguments)]
fn new(
overlap_score_weight: f64,
router_temperature: f64,
......@@ -50,6 +51,9 @@ impl KvRouterConfig {
router_track_active_blocks: bool,
router_snapshot_threshold: Option<u32>,
router_reset_states: bool,
router_ttl_secs: f64,
router_max_tree_size: usize,
router_prune_target_ratio: f64,
) -> Self {
KvRouterConfig {
inner: RsKvRouterConfig {
......@@ -60,6 +64,9 @@ impl KvRouterConfig {
router_track_active_blocks,
router_snapshot_threshold,
router_reset_states,
router_ttl_secs,
router_max_tree_size,
router_prune_target_ratio,
},
}
}
......
......@@ -723,33 +723,57 @@ impl KvIndexer {
}
}
/// Bindings for the approximate KV indexer. We need to exactly match the regular KV Indexer
/// interface, so that the router can switch between the two.
/// Bindings for the approximate KV indexer. This is a wrapper around KvIndexer
/// that uses TTL-based expiration and pruning instead of receiving KV events from workers.
#[pyclass]
pub(crate) struct ApproxKvIndexer {
inner: Arc<llm_rs::kv_router::approx::ApproxKvIndexer>,
inner: Arc<llm_rs::kv_router::indexer::KvIndexer>,
}
#[pymethods]
impl ApproxKvIndexer {
#[new]
fn new(component: Component, kv_block_size: usize, ttl_secs: f64) -> PyResult<Self> {
let ttl = tokio::time::Duration::from_secs_f64(ttl_secs);
let prune_config = Some(llm_rs::kv_router::approx::PruneConfig {
max_tree_size: 2usize.pow(20), // 2 ** 20 = 1048576
prune_target_ratio: 0.8,
});
let inner = Arc::new(llm_rs::kv_router::approx::ApproxKvIndexer::new(
component.inner.drt().runtime().child_token(),
#[pyo3(signature = (component, kv_block_size, router_ttl_secs=120.0, router_max_tree_size=1024, router_prune_target_ratio=0.8))]
fn new(
component: Component,
kv_block_size: usize,
router_ttl_secs: f64,
router_max_tree_size: usize,
router_prune_target_ratio: f64,
) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let cancellation_token = component.inner.drt().runtime().child_token();
let kv_indexer_metrics =
llm_rs::kv_router::indexer::KvIndexerMetrics::from_component(&component.inner);
// Build PruneConfig with the provided parameters
let prune_config = llm_rs::kv_router::approx::PruneConfig {
ttl: std::time::Duration::from_secs_f64(router_ttl_secs),
max_tree_size: router_max_tree_size,
prune_target_ratio: router_prune_target_ratio,
};
// Create KvIndexer with pruning enabled, but DO NOT subscribe to events
let inner: Arc<llm_rs::kv_router::indexer::KvIndexer> =
llm_rs::kv_router::indexer::KvIndexer::new_with_frequency(
cancellation_token.clone(),
None, // expiration_duration - not used with prune_config
kv_block_size as u32,
ttl,
prune_config,
));
kv_indexer_metrics,
Some(prune_config),
)
.into();
// Note: We deliberately do NOT call start_kv_router_background here
// because ApproxKvIndexer doesn't use KV events from workers
Ok(Self { inner })
})
}
fn block_size(&self) -> u32 {
self.inner.block_size()
fn block_size(&self) -> usize {
self.inner.block_size() as usize
}
fn find_matches_for_request<'p>(
......
......@@ -613,33 +613,78 @@ class KvIndexer:
class ApproxKvIndexer:
"""
A KV Indexer that doesn't use KV cache events. It instead relies solely on the input tokens.
An approximate KV Indexer that doesn't receive KV cache events from workers.
Instead, it relies on routing decisions with TTL-based expiration and pruning
to estimate which blocks are cached on which workers.
This is useful when:
- Backend engines don't emit KV events
- You want to reduce event processing overhead
- Lower routing accuracy is acceptable
"""
def __init__(self, component: Component, kv_block_size: int, ttl_secs: float) -> None:
...
def __init__(
self,
component: Component,
kv_block_size: int,
router_ttl_secs: float = 120.0,
router_max_tree_size: int = 1024,
router_prune_target_ratio: float = 0.8,
) -> None:
"""
Create a `ApproxKvIndexer` object
Create an `ApproxKvIndexer` object
Args:
component: The component to associate with this indexer
kv_block_size: The KV cache block size
router_ttl_secs: TTL for blocks in seconds (default: 120.0)
router_max_tree_size: Maximum tree size before pruning (default: 1024)
router_prune_target_ratio: Target size ratio after pruning (default: 0.8)
"""
...
def find_matches_for_request(self, token_ids: List[int], lora_id: int) -> OverlapScores:
def find_matches_for_request(
self, token_ids: List[int]
) -> OverlapScores:
"""
Return the overlapping scores of workers for the given token ids.
Args:
token_ids: List of token IDs to find matches for
Returns:
OverlapScores containing worker matching scores and frequencies
"""
...
def block_size(self) -> int:
"""
Return the block size of the ApproxKvIndexer.
Returns:
The KV cache block size
"""
...
def process_routing_decision_for_request(self, tokens: List[int], lora_id: int, worker_id: int) -> None:
async def process_routing_decision_for_request(
self, tokens: List[int], worker_id: int, dp_rank: int = 0
) -> None:
"""
Notify the indexer that a token sequence has been sent to a specific worker.
Notify the indexer that a token sequence has been routed to a specific worker.
This updates the indexer's internal state to track which blocks are likely
cached on which workers based on routing decisions.
Args:
tokens: List of token IDs that were routed
worker_id: The worker ID the request was routed to
dp_rank: The data parallel rank (default: 0)
"""
...
class KvRecorder:
"""
A recorder for KV Router events.
......@@ -978,6 +1023,35 @@ class RouterConfig:
class KvRouterConfig:
"""Values for KV router"""
def __init__(
self,
overlap_score_weight: float = 1.0,
router_temperature: float = 0.0,
use_kv_events: bool = True,
router_replica_sync: bool = False,
router_track_active_blocks: bool = True,
router_snapshot_threshold: Optional[int] = 1000000,
router_reset_states: bool = False,
router_ttl_secs: float = 120.0,
router_max_tree_size: int = 1024,
router_prune_target_ratio: float = 0.8,
) -> None:
"""
Create a KV router configuration.
Args:
overlap_score_weight: Weight for overlap score in worker selection (default: 1.0)
router_temperature: Temperature for worker sampling via softmax (default: 0.0)
use_kv_events: Whether to use KV events from workers (default: True)
router_replica_sync: Enable replica synchronization (default: False)
router_track_active_blocks: Track active blocks for load balancing (default: True)
router_snapshot_threshold: Number of messages before snapshot (default: 1000000)
router_reset_states: Reset router state on startup (default: False)
router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0)
router_max_tree_size: Maximum tree size before pruning when not using KV events (default: 1024)
router_prune_target_ratio: Target size ratio after pruning when not using KV events (default: 0.8)
"""
...
async def register_llm(
......
......@@ -248,27 +248,32 @@ async def test_event_handler(distributed_runtime):
@pytest.mark.asyncio
async def test_approx_kv_indexer(distributed_runtime):
"""Test ApproxKvIndexer with TTL-based block tracking"""
kv_block_size = 32
namespace = "kv_test"
component = "approx_kv"
kv_listener = distributed_runtime.namespace(namespace).component(component)
indexer = ApproxKvIndexer(kv_listener, kv_block_size, 30.0)
# Create ApproxKvIndexer with default TTL (120s)
indexer = ApproxKvIndexer(kv_listener, kv_block_size)
tokens = [0] * (kv_block_size * 2)
# Initially no matches
scores = await indexer.find_matches_for_request(tokens)
assert not scores.scores
worker_id = 0
# Process routing decision - this should add blocks to the indexer
await indexer.process_routing_decision_for_request(tokens, worker_id)
# Now we should have matches
scores = await indexer.find_matches_for_request(tokens)
assert scores.scores
worker_key = (worker_id, 0) # (worker_id, dp_rank)
assert worker_key in scores.scores
assert scores.scores[worker_key] == 2
assert scores.scores[worker_key] == 2 # 2 blocks (tokens is 2 blocks long)
class EventPublisher:
......
......@@ -36,7 +36,6 @@ pub use prefill_router::PrefillRouter;
use crate::{
kv_router::{
approx::ApproxKvIndexer,
approx::PruneConfig,
indexer::{
KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
......@@ -53,6 +52,7 @@ use crate::{
model_card::ModelDeploymentCard,
preprocessor::PreprocessedRequest,
protocols::common::llm_backend::LLMEngineOutput,
tokens::SequenceHash,
};
// [gluo TODO] shouldn't need to be public
......@@ -113,6 +113,15 @@ pub struct KvRouterConfig {
/// Whether to reset the router state on startup (default: false)
pub router_reset_states: bool,
/// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
pub router_ttl_secs: f64,
/// Maximum tree size before pruning (only used when use_kv_events is false, default: 1024)
pub router_max_tree_size: usize,
/// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
pub router_prune_target_ratio: f64,
}
impl Default for KvRouterConfig {
......@@ -125,6 +134,9 @@ impl Default for KvRouterConfig {
router_track_active_blocks: true,
router_snapshot_threshold: Some(1000000),
router_reset_states: false,
router_ttl_secs: 120.0,
router_max_tree_size: 1024,
router_prune_target_ratio: 0.8,
}
}
}
......@@ -141,6 +153,9 @@ impl KvRouterConfig {
track_active_blocks: Option<bool>,
router_snapshot_threshold: Option<Option<u32>>,
router_reset_states: Option<bool>,
router_ttl_secs: Option<f64>,
router_max_tree_size: Option<usize>,
router_prune_target_ratio: Option<f64>,
) -> Self {
let default = Self::default();
Self {
......@@ -153,21 +168,20 @@ impl KvRouterConfig {
router_snapshot_threshold: router_snapshot_threshold
.unwrap_or(default.router_snapshot_threshold),
router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
router_ttl_secs: router_ttl_secs.unwrap_or(default.router_ttl_secs),
router_max_tree_size: router_max_tree_size.unwrap_or(default.router_max_tree_size),
router_prune_target_ratio: router_prune_target_ratio
.unwrap_or(default.router_prune_target_ratio),
}
}
}
// TODO: is there a way (macro) to auto-derive the KvIndexerInterface trait for this
// since both variants implement it
pub enum Indexer {
/// Updates itself based on KV events emitted by backend workers.
/// Updates itself based on KV events emitted by backend workers or routing decisions.
/// Supports TTL-based expiration and size-based pruning.
/// Has the ability to persist and snapshot states.
KvIndexer(KvIndexer),
/// Predicts the cached blocks based on requests on a TTL basis.
/// Currently does not persist or snapshot states (WIP to enable that).
ApproxKvIndexer(ApproxKvIndexer),
/// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
/// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
None,
......@@ -180,7 +194,6 @@ impl Indexer {
) -> Result<OverlapScores, KvRouterError> {
match self {
Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
Indexer::ApproxKvIndexer(indexer) => indexer.find_matches(sequence).await,
Indexer::None => Ok(OverlapScores {
scores: HashMap::new(),
frequencies: Vec::new(),
......@@ -192,7 +205,6 @@ impl Indexer {
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
match self {
Indexer::KvIndexer(indexer) => indexer.dump_events().await,
Indexer::ApproxKvIndexer(indexer) => indexer.dump_events().await,
Indexer::None => {
panic!(
"Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
......@@ -200,6 +212,22 @@ impl Indexer {
}
}
}
/// Records a routing decision in the underlying indexer.
///
/// Forwards `worker`, `local_hashes` and `sequence_hashes` to
/// [`KvIndexer::process_routing_decision`] when an indexer is present and
/// returns its result. With [`Indexer::None`] there is nothing to update, so
/// the call succeeds immediately.
async fn process_routing_decision(
    &self,
    worker: WorkerWithDpRank,
    local_hashes: Vec<LocalBlockHash>,
    sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> {
    match self {
        // No indexer configured: nothing to record.
        Indexer::None => Ok(()),
        Indexer::KvIndexer(kv_indexer) => {
            kv_indexer
                .process_routing_decision(worker, local_hashes, sequence_hashes)
                .await
        }
    }
}
}
/// A KvRouter only decides which worker you should use. It doesn't send you there.
......@@ -253,23 +281,26 @@ impl KvRouter {
let indexer = if kv_router_config.overlap_score_weight == 0.0 {
// When overlap_score_weight is zero, we don't need to track prefixes
Indexer::None
} else if kv_router_config.use_kv_events {
} else {
let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
Indexer::KvIndexer(KvIndexer::new(
cancellation_token.clone(),
block_size,
kv_indexer_metrics,
))
// If use_kv_events is false, enable TTL and pruning for approximate behavior
let prune_config = if !kv_router_config.use_kv_events {
Some(PruneConfig {
ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
max_tree_size: kv_router_config.router_max_tree_size,
prune_target_ratio: kv_router_config.router_prune_target_ratio,
})
} else {
// hard code 120 seconds for now
Indexer::ApproxKvIndexer(ApproxKvIndexer::new(
None
};
Indexer::KvIndexer(KvIndexer::new_with_frequency(
cancellation_token.clone(),
None, // expiration_duration for frequency tracking
block_size,
Duration::from_secs(120),
Some(PruneConfig {
max_tree_size: 2usize.pow(20), // 2 ** 20 = 1048576
prune_target_ratio: 0.8,
}),
kv_indexer_metrics,
prune_config,
))
};
......@@ -284,8 +315,10 @@ impl KvRouter {
)
.await?;
// Start unified background process if using KvIndexer
if let Indexer::KvIndexer(ref kv_indexer) = indexer {
// Start KV event subscriber background process (only when use_kv_events is enabled)
if kv_router_config.use_kv_events
&& let Indexer::KvIndexer(ref kv_indexer) = indexer
{
start_kv_router_background(
component.clone(),
consumer_uuid,
......@@ -343,12 +376,12 @@ impl KvRouter {
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
// Determine who needs seq_hashes
let approx_indexer_needs_it = matches!(self.indexer, Indexer::ApproxKvIndexer(_));
let needs_process_routing = !self.kv_router_config.use_kv_events;
let scheduler_needs_it = self.kv_router_config.router_track_active_blocks;
// Optimize cloning: only clone if both need it, otherwise move
let (maybe_seq_hashes_1, maybe_seq_hashes_2) =
match (approx_indexer_needs_it, scheduler_needs_it) {
match (needs_process_routing, scheduler_needs_it) {
(true, true) => (Some(seq_hashes.clone()), Some(seq_hashes)),
(true, false) => (Some(seq_hashes), None),
(false, true) => (None, Some(seq_hashes)),
......@@ -367,12 +400,12 @@ impl KvRouter {
)
.await?;
if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
indexer
// Process routing decision when not using KV events (approximate mode with TTL/pruning)
if needs_process_routing {
self.indexer
.process_routing_decision(best_worker, block_hashes, maybe_seq_hashes_1.unwrap())
.await
.unwrap();
};
.await?;
}
let overlap_amount = overlap_scores
.scores
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Approximate KV Indexer
//! Pruning and TTL utilities for KV Indexers
//!
//! - This module implements an approximate KV indexer that can be used to find matches for a given sequence of tokens.
//! - It is designed to be used in conjunction with the KV router to find matches for a given sequence of tokens.
//!
//! # Overview
//!
//! - The Approximate KV Indexer, unlike the regular KV Indexer, does not depend on KV events.
//! - The approximate indexer depends only on the input tokens. We can use input tokens + our routing decision to approximate the radix trees across workers.
//!
//! - The thinking behind this is that if we send a request to a worker, and shortly after get a request with a similar prefix, odds
//! are that routing to the same worker will result in a large cache hit.
//! - Another benefit is the ability to bound the size of the radix tree, which is not possible if we were trying to accurately represent
//! the state of each worker.
//! This module provides utilities for managing TTL-based expiration and size-based pruning
//! of blocks in the radix tree. These utilities are used by the KvIndexer to manage
//! memory usage and keep the cache fresh.
use async_trait::async_trait;
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap};
use std::hash::Hash;
use std::sync::OnceLock;
use tokio::sync::{mpsc, oneshot};
use tokio::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
use crate::tokens::{SequenceHash, TokenBlockSequence};
use crate::kv_router::indexer::{
DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores, RadixTree, RouterEvent,
compute_block_hash_for_seq,
};
use crate::kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, WorkerId, WorkerWithDpRank,
};
#[derive(Debug)]
struct MatchRequest {
/// Sequence of tokens.
sequence: Vec<LocalBlockHash>,
/// A channel to send the `OverlapScores` response.
resp: oneshot::Sender<OverlapScores>,
}
#[derive(Debug)]
struct RouterResult {
/// The worker (with dp_rank) that was selected.
worker: WorkerWithDpRank,
/// The local hashes of the tokens sent to the worker.
local_hashes: Vec<LocalBlockHash>,
/// The sequence hashes of the tokens sent to the worker.
sequence_hashes: Vec<u64>,
}
use crate::kv_router::indexer::KvRouterError;
use crate::kv_router::protocols::{ExternalSequenceBlockHash, WorkerWithDpRank};
/// Block entry to be inserted in the [`PruneManager::expirations`] heap.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
struct BlockEntry {
pub struct BlockEntry {
/// The key of the block entry.
key: ExternalSequenceBlockHash,
pub key: ExternalSequenceBlockHash,
/// The worker (with dp_rank) that stored this block.
worker: WorkerWithDpRank,
pub worker: WorkerWithDpRank,
/// The position of this block in the sequence (0-indexed).
seq_position: usize,
pub seq_position: usize,
}
impl PartialOrd for BlockEntry {
......@@ -85,6 +44,8 @@ impl Ord for BlockEntry {
#[derive(Debug, Clone)]
pub struct PruneConfig {
/// Time-to-live duration for blocks before they expire.
pub ttl: Duration,
/// The maximum tree size before pruning is considered.
pub max_tree_size: usize,
/// The target size ratio to prune down to when max_tree_size is exceeded.
......@@ -93,13 +54,23 @@ pub struct PruneConfig {
pub prune_target_ratio: f64,
}
impl Default for PruneConfig {
    /// Builds the default pruning configuration: a 120 second TTL, a maximum
    /// tree size of 2^20 (1048576) and a prune target ratio of 0.8.
    fn default() -> Self {
        // 120 seconds
        let ttl = Duration::from_secs(120);
        // 2^20 = 1048576
        let max_tree_size = 1usize << 20;
        // Prune down to 80% of max
        let prune_target_ratio = 0.8;
        Self {
            ttl,
            max_tree_size,
            prune_target_ratio,
        }
    }
}
/// A data structure to manage a collection of timers, addressable by a key.
/// This is structured as a sort of "priority queue" of keys, where the priority is the expiration time.
/// It supports insertion as well as updating the expiration time of a key.
/// The [`PruneManager::expirations`] heap is lazily updated to reflect the true expiration times in [`PruneManager::timers`]
/// For now, we have a fixed expiration time for all keys.
#[derive(Debug)]
struct PruneManager<K: Clone + Hash + Eq + Ord> {
pub struct PruneManager<K: Clone + Hash + Eq + Ord> {
/// The source of truth. Maps a key to its current expiration instant.
timers: HashMap<K, Instant>,
......@@ -116,18 +87,19 @@ struct PruneManager<K: Clone + Hash + Eq + Ord> {
ttl: Duration,
/// The configuration for tree-size pruning.
prune_config: Option<PruneConfig>,
pub prune_config: Option<PruneConfig>,
}
impl<K: Clone + Hash + Eq + Ord> PruneManager<K> {
/// Creates a new, empty PruneManager.
pub fn new(ttl: Duration, threshold: usize, prune_config: Option<PruneConfig>) -> Self {
pub fn new(threshold: usize, prune_config: PruneConfig) -> Self {
let ttl = prune_config.ttl;
PruneManager {
timers: HashMap::new(),
expirations: BinaryHeap::new(),
ttl,
threshold,
prune_config,
prune_config: Some(prune_config),
}
}
......@@ -247,310 +219,12 @@ impl<K: Clone + Hash + Eq + Ord> PruneManager<K> {
}
}
/// Indexer front-end that is fed by routing decisions instead of worker KV events.
///
/// The struct only holds the sending halves of the channels plus the join handle of the
/// background thread started in `new`; the radix tree itself lives on that thread.
pub struct ApproxKvIndexer {
/// A `CancellationToken` for managing shutdown.
cancel: CancellationToken,
/// A sender for `MatchRequest`s.
match_tx: mpsc::Sender<MatchRequest>,
/// A sender for `RouterResult`s.
route_tx: mpsc::Sender<RouterResult>,
/// A sender for remove worker requests.
remove_worker_tx: mpsc::Sender<WorkerId>,
/// A sender for dump requests.
dump_tx: mpsc::Sender<DumpRequest>,
/// A handle to the background task managing the KV store.
/// Kept in a `OnceLock` so `shutdown` can `take()` and join it.
task: OnceLock<std::thread::JoinHandle<()>>,
/// The size of the KV block this indexer can handle.
kv_block_size: u32,
}
impl ApproxKvIndexer {
/// Creates an [`ApproxKvIndexer`] and starts the background thread that owns its state.
///
/// The spawned thread builds a current-thread Tokio runtime and runs a `select!` loop that
/// owns the [`RadixTree`] and a [`PruneManager`]. The loop reacts, in priority order
/// (`biased`), to: cancellation, worker removal, get-workers requests, size-based prune
/// signals, routing decisions, dump requests, match requests, and TTL expiry.
///
/// ### Arguments
///
/// * `token` - A `CancellationToken`; cancelling it makes the background loop return.
/// * `kv_block_size` - The KV block size, stored on the returned indexer.
/// * `ttl` - Time-to-live handed to the [`PruneManager`] for tracked blocks.
/// * `prune_config` - Optional tree-size pruning configuration handed to the [`PruneManager`].
///
/// ### Panics
///
/// The background thread panics if the Tokio runtime cannot be built.
pub fn new(
    token: CancellationToken,
    kv_block_size: u32,
    ttl: Duration,
    prune_config: Option<PruneConfig>,
) -> Self {
    let (match_tx, mut match_rx) = mpsc::channel::<MatchRequest>(2048);
    let (route_tx, mut route_rx) = mpsc::channel::<RouterResult>(2048);
    let (remove_worker_tx, mut remove_worker_rx) = mpsc::channel::<WorkerId>(16);
    // NOTE: the sender half is never stored on `Self`, so no get-workers request can be
    // submitted through this indexer; the branch below exists only for completeness.
    let (_get_workers_tx, mut get_workers_rx) =
        mpsc::channel::<super::indexer::GetWorkersRequest>(16);
    let (dump_tx, mut dump_rx) = mpsc::channel::<DumpRequest>(16);
    // Capacity 1: one pending "prune now" signal is enough, further `try_send`s just fail
    // with `Full` and are ignored.
    let (prune_tx, mut prune_rx) = mpsc::channel::<()>(1);
    let cancel_clone = token.clone();

    let task = std::thread::spawn(move || {
        // create a new tokio runtime which will only perform work on a single thread
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        runtime.block_on(async move {
            let mut trie = RadixTree::new();
            // Use a reasonable threshold for ttl - can be made configurable if needed
            let mut prune_manager: PruneManager<BlockEntry> =
                PruneManager::new(ttl, 50, prune_config.clone());
            // Monotonic id for the synthetic events this loop feeds into the tree.
            let mut event_id = 0;

            loop {
                // Create a future that sleeps until the next expiration time.
                let expiry_fut = if let Some(next_expiry) = prune_manager.peek_next_expiry() {
                    tokio::time::sleep_until(next_expiry)
                } else {
                    // If there are no timers, sleep forever.
                    tokio::time::sleep(Duration::MAX)
                };

                tokio::select! {
                    biased;

                    _ = cancel_clone.cancelled() => {
                        tracing::debug!("Approximate Indexer progress loop shutting down");
                        return;
                    }

                    Some(worker) = remove_worker_rx.recv() => {
                        trie.remove_worker(worker);
                    }

                    Some(get_workers_req) = get_workers_rx.recv() => {
                        let workers = trie.get_workers();
                        let _ = get_workers_req.resp.send(workers);
                    }

                    Some(_) = prune_rx.recv() => {
                        // The tree has exceeded the max tree size, so proceed with pruning.
                        if let Ok(pruned) = prune_manager.prune(trie.current_size()) {
                            for p in pruned.iter() {
                                event_id += 1;
                                let event = RouterEvent::new(
                                    p.worker.worker_id,
                                    KvCacheEvent {
                                        event_id,
                                        data: KvCacheEventData::Removed(KvCacheRemoveData {
                                            block_hashes: vec![p.key],
                                        }),
                                        dp_rank: p.worker.dp_rank,
                                    }
                                );
                                let _ = trie.apply_event(event);
                            }
                        }
                    }

                    Some(result) = route_rx.recv() => {
                        // Turn the routing decision into a synthetic "stored" event.
                        let hashes = result.local_hashes.iter().zip(result.sequence_hashes.iter());
                        let stored_event = KvCacheEventData::Stored(KvCacheStoreData {
                            parent_hash: None,
                            blocks: hashes.map(|(local_hash, sequence_hash)| KvCacheStoredBlockData {
                                tokens_hash: *local_hash,
                                block_hash: ExternalSequenceBlockHash(*sequence_hash),
                            }).collect(),
                        });
                        event_id += 1;
                        let event = RouterEvent::new(
                            result.worker.worker_id,
                            KvCacheEvent {
                                event_id,
                                data: stored_event,
                                dp_rank: result.worker.dp_rank,
                            }
                        );

                        if trie.apply_event(event).is_ok() {
                            // Start (or refresh) the timers for every block of the sequence.
                            prune_manager.insert(result.sequence_hashes.iter().enumerate().map(|(idx, h)| BlockEntry {
                                key: ExternalSequenceBlockHash(*h),
                                worker: result.worker,
                                seq_position: idx,
                            }).collect());

                            // Check if we need to prune due to tree size exceeding max threshold.
                            if let Some(prune_config) = &prune_manager.prune_config {
                                let current_size = trie.current_size();
                                if current_size > prune_config.max_tree_size {
                                    tracing::info!(
                                        "Pruning: tree size ({}) exceeded max tree size ({}), scheduling pruning",
                                        current_size,
                                        prune_config.max_tree_size
                                    );
                                    // Send a signal to the pruning receiver to schedule pruning.
                                    if let Err(mpsc::error::TrySendError::Closed(_)) = prune_tx.try_send(()) {
                                        tracing::error!("Failed to send prune schedule signal, prune receiver is closed");
                                    }
                                }
                            }
                        }
                    }

                    Some(dump_req) = dump_rx.recv() => {
                        let events = trie.dump_tree_as_events();
                        let _ = dump_req.resp.send(events);
                    }

                    Some(request) = match_rx.recv() => {
                        let scores = trie.find_matches(request.sequence, false);
                        // The requester may have dropped its receiver (e.g. a cancelled
                        // future); that must not take down the indexer thread.
                        let _ = request.resp.send(scores);
                    }

                    _ = expiry_fut => {
                        // TTL expiry: evict every block whose timer has run out.
                        let expired = prune_manager.pop_expired();
                        for e in expired.iter() {
                            event_id += 1;
                            let event = RouterEvent::new(
                                e.worker.worker_id,
                                KvCacheEvent {
                                    event_id,
                                    data: KvCacheEventData::Removed(KvCacheRemoveData {
                                        block_hashes: vec![e.key],
                                    }),
                                    dp_rank: e.worker.dp_rank,
                                }
                            );
                            let _ = trie.apply_event(event);
                        }
                    }
                }
            }
        });
    });

    let once = OnceLock::new();
    once.set(task).unwrap();

    Self {
        cancel: token,
        match_tx,
        route_tx,
        remove_worker_tx,
        dump_tx,
        task: once,
        kv_block_size,
    }
}
pub fn block_size(&self) -> u32 {
self.kv_block_size
}
/// Hands a routing decision (with already-computed hashes) to the background loop.
///
/// The decision is queued on the routing channel as a `RouterResult`. If the channel is
/// closed, `KvRouterError::IndexerDroppedRequest` is returned.
pub async fn process_routing_decision(
    &self,
    worker: WorkerWithDpRank,
    local_hashes: Vec<LocalBlockHash>,
    sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> {
    let decision = RouterResult {
        worker,
        local_hashes,
        sequence_hashes,
    };

    match self.route_tx.send(decision).await {
        Ok(()) => Ok(()),
        Err(_) => Err(KvRouterError::IndexerDroppedRequest),
    }
}
/// Convenience entry point: derives the local and sequence hashes for `tokens`, then
/// forwards them to [`Self::process_routing_decision`].
pub async fn process_routing_decision_for_request(
    &self,
    tokens: &[u32],
    worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
    let block_size = self.kv_block_size;
    let local_hashes = compute_block_hash_for_seq(tokens, block_size);

    let token_blocks = TokenBlockSequence::new(tokens.into(), block_size, None);
    let mut sequence_hashes = Vec::new();
    sequence_hashes.extend(token_blocks.blocks().iter().map(|block| block.sequence_hash()));

    self.process_routing_decision(worker, local_hashes, sequence_hashes)
        .await
}
}
#[async_trait]
impl KvIndexerInterface for ApproxKvIndexer {
async fn find_matches(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> {
let (resp_tx, resp_rx) = oneshot::channel();
let request = MatchRequest {
sequence,
resp: resp_tx,
};
if let Err(e) = self.match_tx.send(request).await {
tracing::error!(
"Failed to send match request: {:?}; the indexer maybe offline",
e
);
return Err(KvRouterError::IndexerOffline);
}
resp_rx
.await
.map_err(|_| KvRouterError::IndexerDroppedRequest)
}
async fn find_matches_for_request(
&self,
tokens: &[u32],
) -> Result<OverlapScores, KvRouterError> {
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size);
self.find_matches(sequence).await
}
async fn apply_event(&mut self, _event: RouterEvent) {
panic!("Approximate Indexer does not support apply_event");
}
async fn remove_worker(&mut self, worker: WorkerId) {
self.remove_worker_tx.send(worker).await.unwrap();
}
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
let (resp_tx, resp_rx) = oneshot::channel();
let dump_req = DumpRequest { resp: resp_tx };
if let Err(e) = self.dump_tx.send(dump_req).await {
tracing::error!("Failed to send dump request: {:?}", e);
return Err(KvRouterError::IndexerOffline);
}
resp_rx
.await
.map_err(|_| KvRouterError::IndexerDroppedRequest)
}
fn shutdown(&mut self) {
self.cancel.cancel();
if let Some(task) = self.task.take() {
task.join()
.expect("Failed to join approximate indexer task");
}
}
}
impl Drop for ApproxKvIndexer {
fn drop(&mut self) {
self.shutdown();
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::kv_router::indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics};
use crate::kv_router::protocols::{WorkerId, WorkerWithDpRank};
use std::sync::Arc;
use tokio::time::{self, Duration, Instant};
use tokio_util::sync::CancellationToken;
......@@ -585,7 +259,12 @@ mod tests {
#[tokio::test]
async fn test_prune_manager_expiry() {
const TTL: Duration = Duration::from_millis(50);
let mut pm: PruneManager<u32> = PruneManager::new(TTL, 50, None);
let prune_config = PruneConfig {
ttl: TTL,
max_tree_size: usize::MAX, // Effectively disable size-based pruning
prune_target_ratio: 0.5,
};
let mut pm: PruneManager<u32> = PruneManager::new(50, prune_config);
pm.insert(vec![1, 2, 3]);
assert!(pm.get_expiry(&1).is_some());
......@@ -606,7 +285,12 @@ mod tests {
async fn test_prune_manager_update_resets_ttl() {
// Validate that reinserting an existing key extends its TTL and prevents premature expiry.
const TTL: Duration = Duration::from_millis(50);
let mut pm: PruneManager<u32> = PruneManager::new(TTL, 50, None);
let prune_config = PruneConfig {
ttl: TTL,
max_tree_size: usize::MAX,
prune_target_ratio: 0.5,
};
let mut pm: PruneManager<u32> = PruneManager::new(50, prune_config);
// Initial insert and capture the original expiry.
pm.insert(vec![42]);
......@@ -638,7 +322,7 @@ mod tests {
assert_eq!(expired_after, vec![42]);
}
/// End-to-end test for [`ApproxKvIndexer`]:
/// End-to-end test for [`KvIndexer`] with TTL:
/// 1. No matches before routing decision
/// 2. Matches appear after `process_routing_decision`
/// 3. Matches disappear after TTL expiry
......@@ -646,7 +330,19 @@ mod tests {
async fn test_approx_kv_indexer_basic_flow() {
const TTL: Duration = Duration::from_millis(200);
let cancel = CancellationToken::new();
let indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL, None);
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let prune_config = PruneConfig {
ttl: TTL,
max_tree_size: usize::MAX,
prune_target_ratio: 0.5,
};
let indexer = KvIndexer::new_with_frequency(
cancel.clone(),
None,
KV_BLOCK_SIZE,
metrics,
Some(prune_config),
);
let tokens: Vec<u32> = vec![1, 2, 3, 4]; // Exactly one KV block
let worker_id: WorkerId = 0;
......@@ -688,7 +384,19 @@ mod tests {
async fn test_remove_worker() {
const TTL: Duration = Duration::from_secs(5); // Large enough to avoid expiry during test
let cancel = CancellationToken::new();
let mut indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL, None);
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let prune_config = PruneConfig {
ttl: TTL,
max_tree_size: usize::MAX,
prune_target_ratio: 0.5,
};
let mut indexer = KvIndexer::new_with_frequency(
cancel.clone(),
None,
KV_BLOCK_SIZE,
metrics,
Some(prune_config),
);
let tokens: Vec<u32> = vec![10, 11, 12, 13];
let worker_id: WorkerId = 7;
......@@ -727,7 +435,19 @@ mod tests {
const TTL: Duration = Duration::from_secs(5); // Large enough to avoid expiry during test
let cancel = CancellationToken::new();
let mut indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL, None);
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let prune_config = PruneConfig {
ttl: TTL,
max_tree_size: usize::MAX,
prune_target_ratio: 0.5,
};
let mut indexer = KvIndexer::new_with_frequency(
cancel.clone(),
None,
KV_BLOCK_SIZE,
metrics,
Some(prune_config),
);
let tokens: Vec<u32> = vec![100, 101, 102, 103];
let worker_0: WorkerId = 30;
......@@ -785,7 +505,19 @@ mod tests {
const TTL: Duration = Duration::from_secs(5);
let cancel = CancellationToken::new();
let indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL, None);
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let prune_config = PruneConfig {
ttl: TTL,
max_tree_size: usize::MAX,
prune_target_ratio: 0.5,
};
let indexer = KvIndexer::new_with_frequency(
cancel.clone(),
None,
KV_BLOCK_SIZE,
metrics,
Some(prune_config),
);
// Sequence A : single block
let seq_a: Vec<u32> = vec![1, 2, 3, 4];
......@@ -831,7 +563,19 @@ mod tests {
const TTL: Duration = Duration::from_secs(5);
let cancel = CancellationToken::new();
let indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL, None);
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let prune_config = PruneConfig {
ttl: TTL,
max_tree_size: usize::MAX,
prune_target_ratio: 0.5,
};
let indexer = KvIndexer::new_with_frequency(
cancel.clone(),
None,
KV_BLOCK_SIZE,
metrics,
Some(prune_config),
);
let tokens: Vec<u32> = vec![9, 8, 7, 6];
let worker_0: WorkerId = 21;
......@@ -888,11 +632,12 @@ mod tests {
async fn test_prune_manager_no_prune_when_within_bounds() {
const TTL: Duration = Duration::from_secs(10);
let prune_config = PruneConfig {
ttl: TTL,
max_tree_size: 100,
prune_target_ratio: 0.5,
};
let mut pm: PruneManager<u32> = PruneManager::new(TTL, 50, Some(prune_config));
let mut pm: PruneManager<u32> = PruneManager::new(50, prune_config);
// Insert 50 keys (well below max_tree_size of 100)
pm.insert((0..50).collect());
......@@ -912,11 +657,12 @@ mod tests {
async fn test_prune_manager_prune_removes_oldest_first() {
const TTL: Duration = Duration::from_secs(10);
let prune_config = PruneConfig {
ttl: TTL,
max_tree_size: 10,
prune_target_ratio: 0.5,
};
let mut pm: PruneManager<u32> = PruneManager::new(TTL, 50, Some(prune_config));
let mut pm: PruneManager<u32> = PruneManager::new(50, prune_config);
// Insert keys one at a time with delays to ensure different timestamps
for i in 1..=15 {
......@@ -945,7 +691,14 @@ mod tests {
#[tokio::test]
async fn test_prune_manager_prune_fails_without_config() {
const TTL: Duration = Duration::from_secs(10);
let mut pm: PruneManager<u32> = PruneManager::new(TTL, 50, None);
let prune_config = PruneConfig {
ttl: TTL,
max_tree_size: usize::MAX,
prune_target_ratio: 0.5,
};
let mut pm: PruneManager<u32> = PruneManager::new(50, prune_config);
// Temporarily set prune_config to None to test the error case
pm.prune_config = None;
pm.insert(vec![1, 2, 3]);
......@@ -975,7 +728,7 @@ mod tests {
assert!(entry1 < entry2);
}
/// End-to-end test for [`ApproxKvIndexer`] with pruning
/// End-to-end test for [`KvIndexer`] with TTL and pruning
/// 0. Max tree size is 5, target size is 2 (prune_target_ratio = 0.4)
/// 1. Insert 5 blocks (at max_tree_size but not exceeding)
/// 2. Verify all 5 blocks are present
......@@ -986,12 +739,20 @@ mod tests {
async fn test_approx_indexer_e2e_pruning() {
const TTL: Duration = Duration::from_secs(60); // Long TTL to avoid expiry
let prune_config = PruneConfig {
ttl: TTL,
max_tree_size: 5, // Very small to trigger pruning quickly
prune_target_ratio: 0.4, // target size is 5 * 0.4 = 2
};
let cancel = CancellationToken::new();
let indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL, Some(prune_config));
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let indexer = KvIndexer::new_with_frequency(
cancel.clone(),
None,
KV_BLOCK_SIZE,
metrics,
Some(prune_config),
);
let worker = WorkerWithDpRank::from_worker_id(42);
......@@ -1059,11 +820,12 @@ mod tests {
async fn test_prune_manager_prune_reinsertion_updates_position() {
const TTL: Duration = Duration::from_secs(10);
let prune_config = PruneConfig {
ttl: TTL,
max_tree_size: 5,
prune_target_ratio: 0.8,
};
let mut pm: PruneManager<u32> = PruneManager::new(TTL, 50, Some(prune_config));
let mut pm: PruneManager<u32> = PruneManager::new(50, prune_config);
// Insert keys
for i in 1..=10 {
......
......@@ -54,8 +54,9 @@ use xxhash_rust::xxh3;
pub const XXH3_SEED: u64 = 1337;
use crate::kv_router::approx::{BlockEntry, PruneConfig, PruneManager};
use crate::kv_router::protocols::*;
use crate::tokens::SequenceHash;
use crate::tokens::{SequenceHash, TokenBlockSequence};
/// Errors that can occur in the KV Router.
#[derive(Debug, thiserror::Error)]
......@@ -833,6 +834,39 @@ pub trait KvIndexerInterface {
///
/// A vector of RouterEvents representing the current state of the tree.
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError>;
/// Process a routing decision with pre-computed hashes.
///
/// ### Arguments
///
/// * `worker` - The worker (with dp_rank) that was selected.
/// * `local_hashes` - The local hashes of the tokens sent to the worker.
/// * `sequence_hashes` - The sequence hashes of the tokens sent to the worker.
///
/// ### Returns
///
/// `Ok(())` on success, or a `KvRouterError` if the decision could not be handed to the indexer.
async fn process_routing_decision(
&self,
worker: WorkerWithDpRank,
local_hashes: Vec<LocalBlockHash>,
sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError>;
/// Process a routing decision for a request with tokens.
///
/// The implementations in this file compute the local and sequence hashes from `tokens`
/// and then call `process_routing_decision`.
///
/// ### Arguments
///
/// * `tokens` - A vector of `u32` tokens.
/// * `worker` - The worker (with dp_rank) that was selected.
///
/// ### Returns
///
/// `Ok(())` on success, or a `KvRouterError` if the decision could not be handed to the indexer.
async fn process_routing_decision_for_request(
&self,
tokens: &[u32],
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError>;
}
/// A request to process a routing decision.
///
/// Sent over the indexer's routing channel and turned into a synthetic "stored" event by
/// the background loop.
struct RoutingDecisionRequest {
/// The worker (with dp_rank) the request was routed to.
worker: WorkerWithDpRank,
/// Local block hashes; used as the `tokens_hash` of each stored block.
local_hashes: Vec<LocalBlockHash>,
/// Sequence hashes; used as the `block_hash` of each stored block and as the pruning key.
/// Zipped pairwise with `local_hashes`.
sequence_hashes: Vec<SequenceHash>,
}
/// The KV Indexer, managing the KV store and handling events and match requests.
......@@ -849,6 +883,8 @@ pub struct KvIndexer {
get_workers_tx: mpsc::Sender<GetWorkersRequest>,
/// A sender for dump requests.
dump_tx: mpsc::Sender<DumpRequest>,
/// A sender for routing decision requests.
routing_tx: mpsc::Sender<RoutingDecisionRequest>,
/// A handle to the background task managing the KV store.
task: OnceLock<std::thread::JoinHandle<()>>,
/// The size of the KV block this indexer can handle.
......@@ -862,6 +898,8 @@ impl KvIndexer {
///
/// * `token` - A `CancellationToken` for managing shutdown.
/// * `expiration_duration` - The amount of time that block usage should be buffered.
/// * `prune_config` - Optional TTL and tree-size pruning configuration. The block time-to-live is
///   taken from its `ttl` field; `None` disables TTL expiry and size-based pruning.
///
/// ### Returns
///
......@@ -871,12 +909,15 @@ impl KvIndexer {
expiration_duration: Option<Duration>,
kv_block_size: u32,
metrics: Arc<KvIndexerMetrics>,
prune_config: Option<PruneConfig>,
) -> Self {
let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(2048);
let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128);
let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16);
let (get_workers_tx, get_workers_rx) = mpsc::channel::<GetWorkersRequest>(16);
let (dump_tx, dump_rx) = mpsc::channel::<DumpRequest>(16);
let (routing_tx, mut routing_rx) = mpsc::channel::<RoutingDecisionRequest>(2048);
let (prune_tx, mut prune_rx) = mpsc::channel::<()>(1);
let cancel_clone = token.clone();
let task = std::thread::spawn(move || {
......@@ -894,7 +935,22 @@ impl KvIndexer {
let mut get_workers_rx = get_workers_rx;
let mut dump_rx = dump_rx;
let mut trie = RadixTree::new_with_frequency(expiration_duration);
// Create PruneManager if prune_config is specified
let mut prune_manager = prune_config.map(|config| {
PruneManager::<BlockEntry>::new(50, config)
});
let mut event_id_counter = 0u64;
loop {
// Create a future that sleeps until the next expiration time
let expiry_fut = if let Some(ref pm) = prune_manager
&& let Some(next_expiry) = pm.peek_next_expiry() {
tokio::time::sleep_until(next_expiry)
} else {
tokio::time::sleep(Duration::MAX)
};
tokio::select! {
biased;
......@@ -912,10 +968,59 @@ impl KvIndexer {
let _ = get_workers_req.resp.send(workers);
}
Some(_) = prune_rx.recv() => {
// Tree size-based pruning triggered
let Some(ref mut pm) = prune_manager else { continue };
let Ok(pruned) = pm.prune(trie.current_size()) else { continue };
for p in pruned {
event_id_counter += 1;
let event = RouterEvent::new(
p.worker.worker_id,
KvCacheEvent {
event_id: event_id_counter,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![p.key],
}),
dp_rank: p.worker.dp_rank,
}
);
let _ = trie.apply_event(event);
}
}
Some(event) = event_rx.recv() => {
let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
let result = trie.apply_event(event);
let result = trie.apply_event(event.clone());
let result_is_ok = result.is_ok();
metrics.increment_event_applied(event_type, result);
// Track blocks in PruneManager if TTL is enabled and event was stored successfully
let Some(ref mut pm) = prune_manager else { continue };
if !result_is_ok { continue };
let KvCacheEventData::Stored(ref store_data) = event.event.data else { continue };
let worker = WorkerWithDpRank::new(event.worker_id, event.event.dp_rank);
let block_entries: Vec<BlockEntry> = store_data.blocks.iter().enumerate().map(|(idx, block)| {
BlockEntry {
key: block.block_hash,
worker,
seq_position: idx,
}
}).collect();
pm.insert(block_entries);
// Check if we need to prune due to tree size
let Some(ref pc) = pm.prune_config else { continue };
let current_size = trie.current_size();
if current_size > pc.max_tree_size {
tracing::info!(
"Pruning: tree size ({}) exceeded max tree size ({}), scheduling pruning",
current_size,
pc.max_tree_size
);
let _ = prune_tx.try_send(());
}
}
Some(dump_req) = dump_rx.recv() => {
......@@ -923,10 +1028,81 @@ impl KvIndexer {
let _ = dump_req.resp.send(events);
}
Some(routing_req) = routing_rx.recv() => {
// Process routing decisions when TTL/pruning is enabled
let Some(ref mut pm) = prune_manager else { continue };
event_id_counter += 1;
let hashes = routing_req.local_hashes.iter().zip(routing_req.sequence_hashes.iter());
let stored_event = KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: hashes.map(|(local_hash, sequence_hash)| KvCacheStoredBlockData {
tokens_hash: *local_hash,
block_hash: ExternalSequenceBlockHash(*sequence_hash),
}).collect(),
});
let event = RouterEvent::new(
routing_req.worker.worker_id,
KvCacheEvent {
event_id: event_id_counter,
data: stored_event,
dp_rank: routing_req.worker.dp_rank,
}
);
if trie.apply_event(event).is_err() {
continue;
}
let block_entries: Vec<BlockEntry> = routing_req.sequence_hashes.iter().enumerate().map(|(idx, h)| {
BlockEntry {
key: ExternalSequenceBlockHash(*h),
worker: routing_req.worker,
seq_position: idx,
}
}).collect();
pm.insert(block_entries);
// Check if we need to prune due to tree size
let Some(ref pc) = pm.prune_config else { continue };
let current_size = trie.current_size();
if current_size > pc.max_tree_size {
tracing::info!(
"Pruning: tree size ({}) exceeded max tree size ({}), scheduling pruning",
current_size,
pc.max_tree_size
);
let _ = prune_tx.try_send(());
}
}
Some(req) = match_rx.recv() => {
let matches = trie.find_matches(req.sequence, req.early_exit);
let _ = req.resp.send(matches);
}
_ = expiry_fut => {
// TTL-based expiry triggered
let Some(ref mut pm) = prune_manager else { continue };
let expired = pm.pop_expired();
for e in expired {
event_id_counter += 1;
let event = RouterEvent::new(
e.worker.worker_id,
KvCacheEvent {
event_id: event_id_counter,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![e.key],
}),
dp_rank: e.worker.dp_rank,
}
);
let _ = trie.apply_event(event);
}
}
}
}
});
......@@ -944,6 +1120,7 @@ impl KvIndexer {
remove_worker_tx,
get_workers_tx,
dump_tx,
routing_tx,
task: once,
kv_block_size,
}
......@@ -958,7 +1135,7 @@ impl KvIndexer {
kv_block_size: u32,
metrics: Arc<KvIndexerMetrics>,
) -> Self {
Self::new_with_frequency(token, None, kv_block_size, metrics)
Self::new_with_frequency(token, None, kv_block_size, metrics, None)
}
/// Get a sender for `RouterEvent`s.
......@@ -1066,6 +1243,40 @@ impl KvIndexerInterface for KvIndexer {
.await
.map_err(|_| KvRouterError::IndexerDroppedRequest)
}
/// Queues a routing decision for the background loop.
///
/// Fails with `KvRouterError::IndexerDroppedRequest` when the routing channel is closed.
async fn process_routing_decision(
    &self,
    worker: WorkerWithDpRank,
    local_hashes: Vec<LocalBlockHash>,
    sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> {
    let request = RoutingDecisionRequest {
        worker,
        local_hashes,
        sequence_hashes,
    };

    match self.routing_tx.send(request).await {
        Ok(()) => Ok(()),
        Err(_) => Err(KvRouterError::IndexerDroppedRequest),
    }
}
/// Derives the local and sequence hashes for `tokens`, then forwards them to
/// `process_routing_decision`.
async fn process_routing_decision_for_request(
    &self,
    tokens: &[u32],
    worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
    let block_size = self.kv_block_size;
    let local_hashes = compute_block_hash_for_seq(tokens, block_size);

    let token_blocks = TokenBlockSequence::new(tokens.into(), block_size, None);
    let mut sequence_hashes = Vec::new();
    sequence_hashes.extend(token_blocks.blocks().iter().map(|block| block.sequence_hash()));

    self.process_routing_decision(worker, local_hashes, sequence_hashes)
        .await
}
}
impl Drop for KvIndexer {
......@@ -1107,6 +1318,7 @@ pub struct KvIndexerSharded {
request_broadcast_tx: broadcast::Sender<ShardedMatchRequest>,
remove_worker_tx: Vec<mpsc::Sender<WorkerId>>,
dump_tx: Vec<mpsc::Sender<DumpRequest>>,
routing_tx: Vec<mpsc::Sender<RoutingDecisionRequest>>,
tasks: Vec<JoinHandle<()>>,
}
......@@ -1118,6 +1330,8 @@ impl KvIndexerSharded {
/// * `token` - A `CancellationToken` for managing shutdown.
/// * `shards` - A list of kvindexer shards.
/// * `expiration_duration` - The amount of time that block usage should be buffered.
/// * `prune_config` - Optional TTL and tree-size pruning configuration, cloned into every shard.
///   The block time-to-live is taken from its `ttl` field; `None` disables TTL expiry and
///   size-based pruning.
///
/// ### Returns
///
......@@ -1128,6 +1342,7 @@ impl KvIndexerSharded {
expiration_duration: Option<Duration>,
kv_block_size: u32,
metrics: Arc<KvIndexerMetrics>,
prune_config: Option<PruneConfig>,
) -> Self {
let worker_assignments: HashMap<WorkerId, usize> = HashMap::new();
let worker_counts: Vec<usize> = vec![0; num_shards];
......@@ -1135,7 +1350,8 @@ impl KvIndexerSharded {
let mut event_tx = Vec::new();
let mut remove_worker_tx = Vec::new();
let mut get_workers_tx = Vec::new();
let mut dump_tx = Vec::new(); // Add dump channels
let mut dump_tx = Vec::new();
let mut routing_tx = Vec::new();
let mut tasks = Vec::new();
let (request_broadcast_tx, _) = broadcast::channel::<ShardedMatchRequest>(1048576);
......@@ -1146,15 +1362,20 @@ impl KvIndexerSharded {
mpsc::channel::<WorkerId>(16);
let (shard_get_workers_tx, mut shard_get_workers_rx) =
mpsc::channel::<GetWorkersRequest>(16);
let (shard_dump_tx, mut shard_dump_rx) = mpsc::channel::<DumpRequest>(16); // Add dump channel
let (shard_dump_tx, mut shard_dump_rx) = mpsc::channel::<DumpRequest>(16);
let (shard_routing_tx, mut shard_routing_rx) =
mpsc::channel::<RoutingDecisionRequest>(2048);
let (shard_prune_tx, mut shard_prune_rx) = mpsc::channel::<()>(1);
let mut shard_broadcast_rx = request_broadcast_tx.subscribe();
let cancel = token.clone();
let metrics = metrics.clone();
let prune_config_clone = prune_config.clone();
event_tx.push(shard_event_tx);
remove_worker_tx.push(shard_remove_worker_tx);
get_workers_tx.push(shard_get_workers_tx);
dump_tx.push(shard_dump_tx); // Store dump sender
dump_tx.push(shard_dump_tx);
routing_tx.push(shard_routing_tx);
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
......@@ -1164,7 +1385,22 @@ impl KvIndexerSharded {
tasks.push(std::thread::spawn(move || {
runtime.block_on(async move {
let mut trie = RadixTree::new_with_frequency(expiration_duration);
// Create PruneManager if prune_config is specified
let mut prune_manager = prune_config_clone.map(|config| {
PruneManager::<BlockEntry>::new(50, config)
});
let mut event_id_counter = 0u64;
loop {
// Create a future that sleeps until the next expiration time
let expiry_fut = if let Some(ref pm) = prune_manager
&& let Some(next_expiry) = pm.peek_next_expiry() {
tokio::time::sleep_until(next_expiry)
} else {
tokio::time::sleep(Duration::MAX)
};
tokio::select! {
biased;
......@@ -1182,10 +1418,109 @@ impl KvIndexerSharded {
let _ = get_workers_req.resp.send(workers);
}
Some(_) = shard_prune_rx.recv() => {
// Tree size-based pruning triggered
let Some(ref mut pm) = prune_manager else { continue };
let Ok(pruned) = pm.prune(trie.current_size()) else { continue };
for p in pruned {
event_id_counter += 1;
let event = RouterEvent::new(
p.worker.worker_id,
KvCacheEvent {
event_id: event_id_counter,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![p.key],
}),
dp_rank: p.worker.dp_rank,
}
);
let _ = trie.apply_event(event);
}
}
Some(event) = shard_event_rx.recv() => {
let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
let result = trie.apply_event(event);
let result = trie.apply_event(event.clone());
let result_is_ok = result.is_ok();
metrics.increment_event_applied(event_type, result);
// Track blocks in PruneManager if TTL is enabled and event was stored successfully
let Some(ref mut pm) = prune_manager else { continue };
if !result_is_ok { continue };
let KvCacheEventData::Stored(ref store_data) = event.event.data else { continue };
let worker = WorkerWithDpRank::new(event.worker_id, event.event.dp_rank);
let block_entries: Vec<BlockEntry> = store_data.blocks.iter().enumerate().map(|(idx, block)| {
BlockEntry {
key: block.block_hash,
worker,
seq_position: idx,
}
}).collect();
pm.insert(block_entries);
// Check if we need to prune due to tree size
let Some(ref pc) = pm.prune_config else { continue };
let current_size = trie.current_size();
if current_size > pc.max_tree_size {
tracing::info!(
"Pruning: tree size ({}) exceeded max tree size ({}), scheduling pruning",
current_size,
pc.max_tree_size
);
let _ = shard_prune_tx.try_send(());
}
}
Some(routing_req) = shard_routing_rx.recv() => {
// Process routing decisions when TTL/pruning is enabled
let Some(ref mut pm) = prune_manager else { continue };
event_id_counter += 1;
let hashes = routing_req.local_hashes.iter().zip(routing_req.sequence_hashes.iter());
let stored_event = KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: hashes.map(|(local_hash, sequence_hash)| KvCacheStoredBlockData {
tokens_hash: *local_hash,
block_hash: ExternalSequenceBlockHash(*sequence_hash),
}).collect(),
});
let event = RouterEvent::new(
routing_req.worker.worker_id,
KvCacheEvent {
event_id: event_id_counter,
data: stored_event,
dp_rank: routing_req.worker.dp_rank,
}
);
if trie.apply_event(event).is_err() {
continue;
}
let block_entries: Vec<BlockEntry> = routing_req.sequence_hashes.iter().enumerate().map(|(idx, h)| {
BlockEntry {
key: ExternalSequenceBlockHash(*h),
worker: routing_req.worker,
seq_position: idx,
}
}).collect();
pm.insert(block_entries);
// Check if we need to prune due to tree size
let Some(ref pc) = pm.prune_config else { continue };
let current_size = trie.current_size();
if current_size > pc.max_tree_size {
tracing::info!(
"Pruning: tree size ({}) exceeded max tree size ({}), scheduling pruning",
current_size,
pc.max_tree_size
);
let _ = shard_prune_tx.try_send(());
}
}
Some(dump_req) = shard_dump_rx.recv() => {
......@@ -1199,6 +1534,27 @@ impl KvIndexerSharded {
tracing::trace!("Failed to send match response: {:?}", e);
}
}
_ = expiry_fut => {
// TTL-based expiry triggered
let Some(ref mut pm) = prune_manager else { continue };
let expired = pm.pop_expired();
for e in expired {
event_id_counter += 1;
let event = RouterEvent::new(
e.worker.worker_id,
KvCacheEvent {
event_id: event_id_counter,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![e.key],
}),
dp_rank: e.worker.dp_rank,
}
);
let _ = trie.apply_event(event);
}
}
}
}
});
......@@ -1215,7 +1571,8 @@ impl KvIndexerSharded {
event_tx,
request_broadcast_tx,
remove_worker_tx,
dump_tx, // Add dump_tx field
dump_tx,
routing_tx,
tasks,
}
}
......@@ -1230,7 +1587,7 @@ impl KvIndexerSharded {
kv_block_size: u32,
metrics: Arc<KvIndexerMetrics>,
) -> Self {
Self::new_with_frequency(token, num_shards, None, kv_block_size, metrics)
Self::new_with_frequency(token, num_shards, None, kv_block_size, metrics, None)
}
}
......@@ -1358,6 +1715,47 @@ impl KvIndexerInterface for KvIndexerSharded {
Ok(all_events)
}
/// Queues a routing decision on the shard that owns `worker`.
///
/// Workers with no recorded shard assignment are sent to shard 0. Fails with
/// `KvRouterError::IndexerDroppedRequest` when that shard's routing channel is closed.
async fn process_routing_decision(
    &self,
    worker: WorkerWithDpRank,
    local_hashes: Vec<LocalBlockHash>,
    sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> {
    // Route to the appropriate shard based on worker assignment
    let shard_idx = self
        .worker_assignments
        .get(&worker.worker_id)
        .map_or(0, |idx| *idx);

    let request = RoutingDecisionRequest {
        worker,
        local_hashes,
        sequence_hashes,
    };

    match self.routing_tx[shard_idx].send(request).await {
        Ok(()) => Ok(()),
        Err(_) => Err(KvRouterError::IndexerDroppedRequest),
    }
}
/// Derives the local and sequence hashes for `tokens`, then forwards them to
/// `process_routing_decision`.
async fn process_routing_decision_for_request(
    &self,
    tokens: &[u32],
    worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
    let block_size = self.kv_block_size;
    let local_hashes = compute_block_hash_for_seq(tokens, block_size);

    let token_blocks = TokenBlockSequence::new(tokens.into(), block_size, None);
    let mut sequence_hashes = Vec::new();
    sequence_hashes.extend(token_blocks.blocks().iter().map(|block| block.sequence_hash()));

    self.process_routing_decision(worker, local_hashes, sequence_hashes)
        .await
}
}
impl Drop for KvIndexerSharded {
......@@ -2104,6 +2502,7 @@ mod tests {
Some(expiration),
kv_block_size,
metrics,
None,
));
} else {
kv_indexer = Box::new(KvIndexerSharded::new_with_frequency(
......@@ -2112,6 +2511,7 @@ mod tests {
Some(expiration),
kv_block_size,
metrics,
None,
));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment