Unverified commit cbedd1db, authored by Byron Hsu, committed by GitHub

[router] cache-aware load-balancing router v1 (#2114)

parent ad47749b
import itertools
import json
import os
import random
import string
import threading
import time
from argparse import ArgumentParser
from pathlib import Path
from typing import Union
from tqdm import tqdm
import sglang as sgl
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text
random.seed(42)
def gen_prompt(tokenizer, token_num):
all_available_tokens = list(tokenizer.get_vocab().values())
......@@ -24,12 +27,34 @@ def gen_prompt(tokenizer, token_num):
return ret
def get_cache_path(args):
# Create cache directory under ~/.cache/sglang
cache_dir = Path.home() / ".cache" / "sglang"
# Create a unique cache filename based on the arguments that affect generation
cache_key = f"qa_{args.num_qa}_{args.turns}_{args.system_prompt_len}_{args.len_q}_{args.len_a}_{args.tokenizer.replace('/', '_')}.json"
return cache_dir / cache_key
def gen_arguments(args, tokenizer):
multi_qas = [
{"system_prompt": gen_prompt(tokenizer, args.system_prompt_len), "qas": []}
for _ in range(args.num_qa)
]
for i in range(args.num_qa):
cache_path = get_cache_path(args)
# Try to load from cache first
if cache_path.exists():
print(f"Loading cached arguments from {cache_path}")
with open(cache_path, "r") as f:
return json.load(f)
print("Generating new arguments...")
# First progress bar for system prompts
multi_qas = []
for _ in tqdm(range(args.num_qa), desc="Generating system prompts"):
multi_qas.append(
{"system_prompt": gen_prompt(tokenizer, args.system_prompt_len), "qas": []}
)
# Nested progress bars for QA pairs
for i in tqdm(range(args.num_qa), desc="Generating QA pairs"):
qas = multi_qas[i]["qas"]
for j in range(args.turns):
qas.append(
......@@ -38,6 +63,13 @@ def gen_arguments(args, tokenizer):
"new_tokens": args.len_a,
}
)
# Save to cache
cache_path.parent.mkdir(parents=True, exist_ok=True)
with open(cache_path, "w") as f:
json.dump(multi_qas, f)
print(f"Cached arguments saved to {cache_path}")
return multi_qas
......@@ -45,7 +77,7 @@ def gen_arguments(args, tokenizer):
def multi_turns(s, system_prompt, qas):
s += system_prompt
for qa in qas:
for i, qa in enumerate(qas):
s += qa["prompt"]
s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True)
......@@ -62,7 +94,7 @@ def main(args):
multi_qas,
temperature=0,
backend=backend,
num_threads=args.parallel,
num_threads="auto",
progress_bar=True,
)
latency = time.time() - tic
......@@ -75,7 +107,6 @@ def main(args):
value = {
"task": "multi_turn_system_prompt_chat",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_qa,
"num_turns": args.turns,
......
......@@ -727,9 +727,9 @@ def sample_generated_shared_prefix_requests(
total_input_tokens = 0
total_output_tokens = 0
for group_idx in range(num_groups):
for group_idx in tqdm(range(num_groups), desc="Generating system prompts"):
system_prompt = system_prompts[group_idx]
for prompt_idx in range(prompts_per_group):
for prompt_idx in tqdm(range(prompts_per_group), desc="Generating questions"):
question = questions[group_idx * prompts_per_group + prompt_idx]
full_prompt = f"{system_prompt}\n\n{question}"
prompt_len = len(tokenizer.encode(full_prompt))
......
......@@ -48,9 +48,13 @@ def run_eval(args):
# Select backend
set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}"))
# Read data
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
filename = download_and_cache_file(url)
if args.data_path is None:
# Read data
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
filename = download_and_cache_file(url)
else:
filename = args.data_path
lines = list(read_jsonl(filename))
# Construct prompts
......
......@@ -591,6 +591,20 @@ dependencies = [
"syn",
]
[[package]]
name = "dashmap"
version = "6.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf"
dependencies = [
"cfg-if",
"crossbeam-utils",
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]]
name = "deranged"
version = "0.3.11"
......@@ -904,6 +918,12 @@ dependencies = [
"tracing",
]
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]]
name = "hashbrown"
version = "0.15.1"
......@@ -1226,7 +1246,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da"
dependencies = [
"equivalent",
"hashbrown",
"hashbrown 0.15.1",
]
[[package]]
......@@ -2097,7 +2117,9 @@ dependencies = [
"actix-web",
"bytes",
"clap",
"dashmap",
"futures-util",
"http 1.1.0",
"pyo3",
"rand",
"reqwest",
......
......@@ -24,6 +24,8 @@ futures-util = "0.3"
serde_json = "1.0"
pyo3 = { version = "0.22.5", features = ["extension-module"] }
tokenizers = { version = "0.20.3", features = ["http"] }
dashmap = "6.1.0"
http = "1.1.0"
[profile.release]
lto = "thin"
......
......@@ -46,6 +46,9 @@ pip install <path-to-wheel>
#### Option B: Development Mode
For development purposes, you can install the package in editable mode:
Warning: the editable Python binding can suffer from performance degradation! Build a fresh wheel for every update if you want to test performance.
```bash
pip install -e .
```
......
from sglang_router import PolicyType, Router
router = Router(
worker_urls=[
"http://localhost:30000",
"http://localhost:30001",
]
)
router.start()
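Once started, the router exposes the same `/generate` endpoint as the workers and forwards each request according to the configured policy. A minimal request sketch, assuming the router's default address `127.0.0.1:3001` and SGLang's standard `/generate` payload:

```python
import requests

# The router reads the "text" field for prefix matching, then proxies
# the request to whichever worker it selects.
response = requests.post(
    "http://127.0.0.1:3001/generate",
    json={
        "text": "What is the capital of France?",
        "sampling_params": {"max_new_tokens": 32},
    },
)
print(response.json())
```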
import argparse
import os
import signal
import subprocess
import sys
import time
from typing import Dict, List
import requests
from sglang_router import PolicyType, Router
# Global processes list for cleanup
_processes: List[subprocess.Popen] = []
def cleanup_processes(signum=None, frame=None):
"""Cleanup function to kill all worker processes."""
print("\nCleaning up processes...")
for process in _processes:
try:
# Kill the entire process group
pgid = os.getpgid(process.pid)
os.killpg(pgid, signal.SIGKILL)
process.wait()
except Exception:
pass  # worker may have already exited
sys.exit(1)
# Register signal handlers
signal.signal(signal.SIGINT, cleanup_processes)
signal.signal(signal.SIGTERM, cleanup_processes)
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="Launch SGLang Router Server")
parser.add_argument(
"--host", type=str, default="localhost", help="Host address to bind the server"
)
parser.add_argument(
"--port", type=int, default=30000, help="Base port number for workers"
)
parser.add_argument(
"--dp",
type=int,
default=2,
help="Number of worker processes (degree of parallelism)",
)
parser.add_argument(
"--model-path", type=str, required=True, help="Path to the model"
)
parser.add_argument(
"--local-tokenizer-path",
type=str,
required=True,
help="Path to the local tokenizer",
)
return parser.parse_args()
def launch_workers(args) -> tuple[List[subprocess.Popen], List[str]]:
"""Launch all worker processes concurrently using subprocess."""
processes = []
worker_urls = []
# Launch each worker process
for i in range(args.dp):
port = args.port + i
url = f"http://{args.host}:{port}"
worker_urls.append(url)
# TODO: replace this with launch_server, and move this file to sglang/ because it depends on sglang
command = f"export CUDA_VISIBLE_DEVICES={i}; python -m sglang.launch_server --model-path {args.model_path} --host {args.host} --port {port}"
print(command)
process = subprocess.Popen(command, shell=True)
processes.append(process)
_processes.append(process) # Add to global list for cleanup
return processes, worker_urls
def wait_for_healthy_workers(worker_urls: List[str], timeout: int = 300) -> bool:
"""Block until all workers are healthy or timeout is reached."""
start_time = time.time()
healthy_workers: Dict[str, bool] = {url: False for url in worker_urls}
while time.time() - start_time < timeout:
print("checking healthiness...")
all_healthy = True
for url in worker_urls:
if not healthy_workers[url]: # Only check workers that aren't healthy yet
try:
response = requests.get(f"{url}/health")
if response.status_code == 200:
print(f"Worker at {url} is healthy")
healthy_workers[url] = True
else:
all_healthy = False
except requests.RequestException:
all_healthy = False
if all_healthy:
print("All workers are healthy!")
return True
time.sleep(5)
# If we get here, we've timed out
unhealthy_workers = [url for url, healthy in healthy_workers.items() if not healthy]
print(f"Timeout waiting for workers: {unhealthy_workers}")
return False
def main():
"""Main function to launch the router and workers."""
args = parse_args()
processes = None
try:
# Launch all workers concurrently
processes, worker_urls = launch_workers(args)
# Block until all workers are healthy
if not wait_for_healthy_workers(worker_urls):
raise RuntimeError("Failed to start all workers")
# Initialize and start the router
router = Router(
worker_urls=worker_urls,
policy=PolicyType.CacheAware,
)
print("Starting router...")
router.start()
# Keep the main process running
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
print("\nShutting down...")
except Exception as e:
print(f"Error: {e}")
finally:
# Cleanup: Kill all worker processes
if processes:
for process in processes:
process.kill()
if __name__ == "__main__":
main()
import argparse
import dataclasses
import sys
from typing import List, Optional
from sglang_router import Router
from sglang_router_rs import PolicyType
@dataclasses.dataclass
class RouterArgs:
# Worker configuration
worker_urls: List[str]
host: str = "127.0.0.1"
port: int = 30000
# Routing policy
policy: str = "cache_aware"
cache_threshold: float = 0.5
cache_routing_prob: float = 1.0
eviction_interval: int = 60
max_tree_size: int = 2**24
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser,
use_router_prefix: bool = False,
exclude_host_port: bool = False,
):
"""
Add router-specific arguments to an argument parser.
Args:
parser: The argument parser to add arguments to
use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts
exclude_host_port: If True, don't add host and port arguments (used when inheriting from server)
"""
prefix = "router-" if use_router_prefix else ""
# Worker configuration
if not exclude_host_port:
parser.add_argument(
"--host",
type=str,
default=RouterArgs.host,
help="Host address to bind the router server",
)
parser.add_argument(
"--port",
type=int,
default=RouterArgs.port,
help="Port number to bind the router server",
)
parser.add_argument(
"--worker-urls",
type=str,
nargs="+",
help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)",
)
# Routing policy configuration
parser.add_argument(
f"--{prefix}policy",
type=str,
default=RouterArgs.policy,
choices=["random", "round_robin", "cache_aware"],
help="Load balancing policy to use",
)
parser.add_argument(
f"--{prefix}cache-threshold",
type=float,
default=RouterArgs.cache_threshold,
help="Cache threshold (0.0-1.0) for cache-aware routing",
)
parser.add_argument(
f"--{prefix}cache-routing-prob",
type=float,
default=RouterArgs.cache_routing_prob,
help="Probability of using cache-aware routing (0.0-1.0)",
)
parser.add_argument(
f"--{prefix}eviction-interval",
type=int,
default=RouterArgs.eviction_interval,
help="Interval in seconds between cache eviction operations",
)
parser.add_argument(
f"--{prefix}max-tree-size",
type=int,
default=RouterArgs.max_tree_size,
help="Maximum size of the approximation tree for cache-aware routing",
)
@classmethod
def from_cli_args(
cls, args: argparse.Namespace, use_router_prefix: bool = False
) -> "RouterArgs":
"""
Create RouterArgs instance from parsed command line arguments.
Args:
args: Parsed command line arguments
use_router_prefix: If True, look for arguments with 'router-' prefix
"""
prefix = "router_" if use_router_prefix else ""
return cls(
worker_urls=args.worker_urls,
host=args.host,
port=args.port,
policy=getattr(args, f"{prefix}policy"),
cache_threshold=getattr(args, f"{prefix}cache_threshold"),
cache_routing_prob=getattr(args, f"{prefix}cache_routing_prob"),
eviction_interval=getattr(args, f"{prefix}eviction_interval"),
max_tree_size=getattr(args, f"{prefix}max_tree_size"),
)
def policy_from_str(policy_str: str) -> PolicyType:
"""Convert policy string to PolicyType enum."""
policy_map = {
"random": PolicyType.Random,
"round_robin": PolicyType.RoundRobin,
"cache_aware": PolicyType.CacheAware,
}
return policy_map[policy_str]
def launch_router(args: argparse.Namespace) -> Optional[Router]:
"""
Launch the SGLang router with the configuration from parsed arguments.
Args:
args: Namespace object containing router configuration
Can be either raw argparse.Namespace or converted RouterArgs
Returns:
Router instance if successful, None if failed
"""
try:
# Convert to RouterArgs if needed
if not isinstance(args, RouterArgs):
router_args = RouterArgs.from_cli_args(args)
else:
router_args = args
router = Router(
worker_urls=router_args.worker_urls,
policy=policy_from_str(router_args.policy),
host=router_args.host,
port=router_args.port,
cache_threshold=router_args.cache_threshold,
cache_routing_prob=router_args.cache_routing_prob,
eviction_interval_secs=router_args.eviction_interval,
max_tree_size=router_args.max_tree_size,
)
router.start()
return router
except Exception as e:
print(f"Error starting router: {e}", file=sys.stderr)
return None
class CustomHelpFormatter(
argparse.RawDescriptionHelpFormatter, argparse.ArgumentDefaultsHelpFormatter
):
"""Custom formatter that preserves both description formatting and shows defaults"""
pass
def parse_router_args(args: List[str]) -> RouterArgs:
"""Parse command line arguments and return RouterArgs instance."""
parser = argparse.ArgumentParser(
description="""SGLang Router - High-performance request distribution across worker nodes
Usage:
This launcher enables starting a router with individual worker instances. It is useful for
multi-node setups or when you want to start workers and router separately.
Examples:
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 --cache-threshold 0.7 --cache-routing-prob 0.5
""",
formatter_class=CustomHelpFormatter,
)
RouterArgs.add_cli_args(parser, use_router_prefix=False)
return RouterArgs.from_cli_args(parser.parse_args(args), use_router_prefix=False)
def main() -> None:
router_args = parse_router_args(sys.argv[1:])
router = launch_router(router_args)
if router is None:
sys.exit(1)
if __name__ == "__main__":
main()
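`launch_router` also accepts an already-constructed `RouterArgs`, which is handy when embedding the router in another program instead of parsing CLI flags. A sketch under that assumption (worker URLs are placeholders):

```python
from sglang_router.launch_router import RouterArgs, launch_router

# Build the arguments directly; unset fields keep the dataclass defaults.
args = RouterArgs(
    worker_urls=["http://localhost:31000", "http://localhost:31001"],
    policy="cache_aware",
    cache_threshold=0.5,
)
router = launch_router(args)  # returns None if startup failed
```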
import argparse
import copy
import multiprocessing as mp
import os
import signal
import sys
import time
from typing import List
import requests
from sglang_router.launch_router import RouterArgs, launch_router
from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs, prepare_server_args
from sglang.srt.utils import is_port_available
from sglang.utils import get_exception_traceback
def run_server(server_args, dp_rank):
os.setpgrp() # Create new process group
# Set DP_RANK environment variable
os.environ["DP_RANK"] = str(dp_rank)
launch_server(server_args)
def launch_server_process(
server_args: ServerArgs, worker_port: int, dp_id: int
) -> mp.Process:
"""Launch a single server process with the given args and port."""
server_args = copy.deepcopy(server_args)
server_args.port = worker_port
server_args.base_gpu_id = dp_id * server_args.tp_size
server_args.dp_size = 1
proc = mp.Process(target=run_server, args=(server_args, dp_id))
proc.start()
return proc
def cleanup_processes(processes: List[mp.Process]):
"""Clean up all processes using process groups."""
print("\nCleaning up processes...")
for proc in processes:
if proc.is_alive():
try:
# Kill the entire process group
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
# Give processes some time to terminate gracefully
proc.join(timeout=3)
# If process is still alive, force kill
if proc.is_alive():
os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
except ProcessLookupError:
pass # Process already terminated
def setup_signal_handlers(cleanup_func):
"""Setup handlers for various termination signals."""
def signal_handler(signum, frame):
cleanup_func()
sys.exit(1)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
if hasattr(signal, "SIGQUIT"):
signal.signal(signal.SIGQUIT, signal_handler)
def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
"""Wait for server to be healthy by checking /health endpoint."""
start_time = time.time()
url = f"http://{host}:{port}/health"
while time.time() - start_time < timeout:
try:
response = requests.get(url, timeout=5)
if response.status_code == 200:
return True
except requests.exceptions.RequestException:
pass
time.sleep(1)
return False
def find_available_ports(base_port: int, count: int) -> List[int]:
"""Find consecutive available ports starting from base_port."""
available_ports = []
current_port = base_port
while len(available_ports) < count:
if is_port_available(current_port):
available_ports.append(current_port)
current_port += 1
return available_ports
def main():
# CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
mp.set_start_method("spawn")
parser = argparse.ArgumentParser(
description="Launch SGLang router and server processes"
)
ServerArgs.add_cli_args(parser)
RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
parser.add_argument(
"--router-dp-worker-base-port",
type=int,
default=31000,
help="Base port number for data parallel workers",
)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)
# Find available ports for workers
worker_ports = find_available_ports(
args.router_dp_worker_base_port, server_args.dp_size
)
# Start server processes
server_processes = []
try:
# Launch server processes
for i, worker_port in enumerate(worker_ports):
proc = launch_server_process(server_args, worker_port, i)
server_processes.append(proc)
# Setup cleanup handler
setup_signal_handlers(lambda: cleanup_processes(server_processes))
# Wait for all servers to be healthy
all_healthy = True
for port in worker_ports:
if not wait_for_server_health(server_args.host, port):
print(f"Server on port {port} failed to become healthy")
all_healthy = False
break
if not all_healthy:
print("Not all servers are healthy. Shutting down...")
cleanup_processes(server_processes)
sys.exit(1)
print("All servers are healthy. Starting router...")
# Update router args with worker URLs
router_args.worker_urls = [
f"http://{server_args.host}:{port}" for port in worker_ports
]
# Start the router
router = launch_router(router_args)
if router is None:
print("Failed to start router. Shutting down...")
cleanup_processes(server_processes)
sys.exit(1)
except KeyboardInterrupt:
print("\nReceived shutdown signal...")
except Exception as e:
print(f"Error occurred: {e}")
print(get_exception_traceback())
finally:
cleanup_processes(server_processes)
if __name__ == "__main__":
main()
......@@ -9,16 +9,23 @@ class Router:
A high-performance router for distributing requests across worker nodes.
Args:
worker_urls: List of URLs for worker nodes that will handle requests
worker_urls: List of URLs for worker nodes that will handle requests. Each URL should include
the protocol, host, and port (e.g., ['http://worker1:8000', 'http://worker2:8000'])
policy: Load balancing policy to use. Options:
- PolicyType.Random: Randomly select workers
- PolicyType.RoundRobin: Distribute requests in round-robin fashion
- PolicyType.ApproxTree: Tree-based routing using tokenizer similarity
host: Host address to bind the router server
port: Port number to bind the router server
tokenizer_path: Path to tokenizer model file (required for ApproxTree policy)
cache_threshold: Caching threshold value between 0-1
- PolicyType.CacheAware: Distribute requests in cache-aware fashion
host: Host address to bind the router server. Default: '127.0.0.1'
port: Port number to bind the router server. Default: 3001
cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker
if the match rate exceeds threshold, otherwise routes to the worker with the smallest
tree. Default: 0.5
cache_routing_prob: Probability of using cache-aware routing (0.0-1.0). Default 1.0 for
full cache-aware routing, suitable for perfectly divided prefix workloads. For uneven
workloads, use a lower value to better distribute requests
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
routing. Default: 60
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
"""
def __init__(
......@@ -27,17 +34,20 @@ class Router:
policy: PolicyType = PolicyType.RoundRobin,
host: str = "127.0.0.1",
port: int = 3001,
tokenizer_path: Optional[str] = None,
cache_threshold: float = 0.50,
cache_routing_prob: float = 1.0,
eviction_interval_secs: int = 60,
max_tree_size: int = 2**24,
):
self._router = _Router(
worker_urls=worker_urls,
policy=policy,
host=host,
port=port,
tokenizer_path=tokenizer_path,
cache_threshold=cache_threshold,
cache_routing_prob=cache_routing_prob,
eviction_interval_secs=eviction_interval_secs,
max_tree_size=max_tree_size,
)
def start(self) -> None:
......
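For reference, a construction example that spells out the cache-aware defaults documented above (worker URLs are placeholders):

```python
from sglang_router import PolicyType, Router

router = Router(
    worker_urls=["http://localhost:30000", "http://localhost:30001"],
    policy=PolicyType.CacheAware,
    cache_threshold=0.5,        # above this match rate, route to the matched worker
    cache_routing_prob=1.0,     # 1.0 = always cache-aware; lower mixes in shortest-queue
    eviction_interval_secs=60,  # period of the background LRU eviction thread
    max_tree_size=2**24,        # approximate-tree size budget before eviction
)
router.start()
```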
// Python Binding
use pyo3::prelude::*;
pub mod router;
mod server;
pub mod server;
pub mod tree;
#[pyclass(eq)]
......@@ -9,7 +8,7 @@ pub mod tree;
pub enum PolicyType {
Random,
RoundRobin,
ApproxTree,
CacheAware,
}
#[pyclass]
......@@ -18,8 +17,10 @@ struct Router {
port: u16,
worker_urls: Vec<String>,
policy: PolicyType,
tokenizer_path: Option<String>,
cache_threshold: Option<f32>,
cache_threshold: f32,
cache_routing_prob: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
}
#[pymethods]
......@@ -30,33 +31,30 @@ impl Router {
policy = PolicyType::RoundRobin,
host = String::from("127.0.0.1"),
port = 3001,
tokenizer_path = None,
cache_threshold = Some(0.50)
cache_threshold = 0.50,
cache_routing_prob = 1.0,
eviction_interval_secs = 60,
max_tree_size = 2usize.pow(24)
))]
fn new(
worker_urls: Vec<String>,
policy: PolicyType,
host: String,
port: u16,
tokenizer_path: Option<String>,
cache_threshold: Option<f32>,
cache_threshold: f32,
cache_routing_prob: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
) -> PyResult<Self> {
// Validate required parameters for approx_tree policy
if matches!(policy, PolicyType::ApproxTree) {
if tokenizer_path.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"tokenizer_path is required for approx_tree policy",
));
}
}
Ok(Router {
host,
port,
worker_urls,
policy,
tokenizer_path,
cache_threshold,
cache_routing_prob,
eviction_interval_secs,
max_tree_size,
})
}
......@@ -68,14 +66,11 @@ impl Router {
let policy_config = match &self.policy {
PolicyType::Random => router::PolicyConfig::RandomConfig,
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig,
PolicyType::ApproxTree => router::PolicyConfig::ApproxTreeConfig {
tokenizer_path: self
.tokenizer_path
.clone()
.expect("tokenizer_path is required for approx_tree policy"),
cache_threshold: self
.cache_threshold
.expect("cache_threshold is required for approx_tree policy"),
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
cache_threshold: self.cache_threshold,
cache_routing_prob: self.cache_routing_prob,
eviction_interval_secs: self.eviction_interval_secs,
max_tree_size: self.max_tree_size,
},
};
......
// src/main.rs
use clap::Parser;
use clap::ValueEnum;
// declare child modules
mod router;
mod server;
mod tree;
use crate::router::PolicyConfig;
use sglang_router_rs::{router::PolicyConfig, server};
#[derive(Debug, Clone, ValueEnum)]
pub enum PolicyType {
Random,
RoundRobin,
ApproxTree,
CacheAware,
}
#[derive(Parser, Debug)]
......@@ -21,44 +17,70 @@ struct Args {
#[arg(
long,
default_value = "127.0.0.1",
help = "Host address to bind the server to"
help = "Host address to bind the router server to. Default: 127.0.0.1"
)]
host: String,
#[arg(long, default_value_t = 3001, help = "Port number to listen on")]
#[arg(
long,
default_value_t = 3001,
help = "Port number to bind the router server to. Default: 3001"
)]
port: u16,
#[arg(
long,
value_delimiter = ',',
help = "Comma-separated list of worker URLs to distribute requests to"
help = "Comma-separated list of worker URLs that will handle the requests. Each URL should include the protocol, host, and port (e.g., http://worker1:8000,http://worker2:8000)"
)]
worker_urls: Vec<String>,
#[arg(
long,
default_value_t = PolicyType::RoundRobin,
default_value_t = PolicyType::CacheAware,
value_enum,
help = "Load balancing policy to use: random, round_robin, or approx_tree"
help = "Load balancing policy to use for request distribution:\n\
- random: Randomly select workers\n\
- round_robin: Distribute requests in round-robin fashion\n\
- cache_aware: Distribute requests in cache-aware fashion\n"
)]
policy: PolicyType,
#[arg(
long,
default_value_t = 0.5,
requires = "policy",
required_if_eq("policy", "cache_aware"),
help = "Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker if the match rate exceeds threshold, otherwise routes to the worker with the smallest tree. Default: 0.5"
)]
cache_threshold: f32,
#[arg(
long,
default_value_t = 1.0,
requires = "policy",
required_if_eq("policy", "cache_aware"),
help = "Probability of using cache-aware routing (0.0-1.0). Default 1.0 for full cache-aware routing, suitable for perfectly divided prefix workloads. For uneven workloads, use a lower value to better distribute requests"
)]
cache_routing_prob: f32,
#[arg(
long,
default_value_t = 60,
requires = "policy",
required_if_eq("policy", "approx_tree"),
help = "Path to the tokenizer file, required when using approx_tree policy"
required_if_eq("policy", "cache_aware"),
help = "Interval in seconds between cache eviction operations in cache-aware routing. Default: 60"
)]
tokenizer_path: Option<String>,
eviction_interval_secs: u64,
#[arg(
long,
default_value = "0.50",
default_value_t = 2usize.pow(24),
requires = "policy",
required_if_eq("policy", "approx_tree"),
help = "Cache threshold (0.0-1.0) for approx_tree routing. Routes to cached worker if match rate exceeds threshold, otherwise routes to shortest queue worker"
required_if_eq("policy", "cache_aware"),
help = "Maximum size of the approximation tree for cache-aware routing. Default: 2^24"
)]
cache_threshold: Option<f32>,
max_tree_size: usize,
}
impl Args {
......@@ -66,14 +88,11 @@ impl Args {
match self.policy {
PolicyType::Random => PolicyConfig::RandomConfig,
PolicyType::RoundRobin => PolicyConfig::RoundRobinConfig,
PolicyType::ApproxTree => PolicyConfig::ApproxTreeConfig {
tokenizer_path: self
.tokenizer_path
.clone()
.expect("tokenizer_path is required for approx_tree policy"),
cache_threshold: self
.cache_threshold
.expect("cache_threshold is required for approx_tree policy"),
PolicyType::CacheAware => PolicyConfig::CacheAwareConfig {
cache_threshold: self.cache_threshold,
cache_routing_prob: self.cache_routing_prob,
eviction_interval_secs: self.eviction_interval_secs,
max_tree_size: self.max_tree_size,
},
}
}
......
use crate::tree::RadixTree;
use crate::tree::Tree;
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
use bytes::Bytes;
use futures_util::TryStreamExt;
use futures_util::{Stream, StreamExt, TryStreamExt};
use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::pin::Pin;
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex};
use tokenizers::tokenizer::Tokenizer;
use std::thread;
use std::time::Duration;
#[derive(Debug)]
pub enum Router {
......@@ -18,34 +21,88 @@ pub enum Router {
Random {
worker_urls: Vec<String>,
},
ApproxTree {
CacheAware {
/*
Cache-Aware Load Balancing Router
This router combines two strategies to optimize both cache utilization and request distribution:
1. Cache-Aware Routing (Approximate Tree)
2. Load Balancing (Shortest Queue)
For each incoming request, the router chooses between these strategies:
- With probability P: Uses cache-aware routing
- With probability (1-P): Uses load balancing
where P is configured via `cache_routing_prob`
Strategy Details:
1. Cache-Aware Routing (Approximate Tree)
-------------------------------------------
This strategy maintains an approximate radix tree for each worker based on request history,
eliminating the need for direct cache state queries. The tree stores raw text characters
instead of token IDs to avoid tokenization overhead.
Process:
a. For each request, find the worker with the highest prefix match
b. If match rate > cache_threshold:
Route to the worker with highest match (likely has relevant data cached)
c. If match rate ≤ cache_threshold:
Route to the worker with smallest tree size (most available cache capacity)
d. Background maintenance:
Periodically evict least recently used leaf nodes to prevent memory overflow
2. Load Balancing (Shortest Queue)
-------------------------------------------
This strategy tracks pending request counts per worker and routes new requests
to the least busy worker for optimal load distribution.
Configuration Parameters:
------------------------
1. cache_routing_prob: (float, 0.0 to 1.0)
- 0.0: Exclusively use load balancing
- 1.0: Exclusively use cache-aware routing
- Between 0-1: Probability of using cache-aware routing vs load balancing
2. cache_threshold: (float, 0.0 to 1.0)
Minimum prefix match ratio to use highest-match routing.
Below this threshold, routes to worker with most available cache space.
3. eviction_interval_secs: (integer)
Interval between LRU eviction cycles for the approximate trees.
4. max_tree_size: (integer)
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
during the next eviction cycle.
*/
worker_urls: Vec<String>,
// TODO: don't lock the whole tree
url_to_tree: Arc<Mutex<HashMap<String, RadixTree>>>,
tokenizer: Tokenizer,
url_to_count: Arc<Mutex<HashMap<String, usize>>>,
tree: Arc<Mutex<Tree>>,
running_queue: Arc<Mutex<HashMap<String, usize>>>,
processed_queue: Arc<Mutex<HashMap<String, usize>>>,
cache_threshold: f32,
cache_routing_prob: f32,
_eviction_thread: Option<thread::JoinHandle<()>>, // Store thread handle
},
}
#[derive(Debug)]
pub enum PolicyConfig {
RandomConfig,
RoundRobinConfig,
ApproxTreeConfig {
tokenizer_path: String,
CacheAwareConfig {
cache_threshold: f32,
cache_routing_prob: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
},
}
fn get_token_ids_from_request(body: &Bytes, tokenizer: &Tokenizer) -> Vec<u32> {
fn get_text_from_request(body: &Bytes) -> String {
// 1. convert body to json
let json = serde_json::from_slice::<serde_json::Value>(body).unwrap();
// 2. get the text field
let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
// 3. tokenize the text field
let tokens = tokenizer.encode(text, false).unwrap();
tokens.get_ids().to_vec()
return text.to_string();
}
impl Router {
......@@ -56,25 +113,56 @@ impl Router {
worker_urls,
current_index: std::sync::atomic::AtomicUsize::new(0),
},
PolicyConfig::ApproxTreeConfig {
tokenizer_path,
PolicyConfig::CacheAwareConfig {
cache_threshold,
cache_routing_prob,
eviction_interval_secs,
max_tree_size,
} => {
let mut url_to_tree = HashMap::new();
let mut url_to_count = HashMap::new();
let mut running_queue = HashMap::new();
for url in &worker_urls {
running_queue.insert(url.clone(), 0);
}
let mut processed_queue = HashMap::new();
for url in &worker_urls {
processed_queue.insert(url.clone(), 0);
}
let tree = Arc::new(Mutex::new(Tree::new()));
let running_queue = Arc::new(Mutex::new(running_queue));
let processed_queue = Arc::new(Mutex::new(processed_queue));
// Create background eviction thread
let tree_clone = Arc::clone(&tree);
let processed_queue_clone = Arc::clone(&processed_queue);
let eviction_thread = thread::spawn(move || {
loop {
// Sleep for the specified interval
thread::sleep(Duration::from_secs(eviction_interval_secs));
let locked_tree_clone = tree_clone.lock().unwrap();
// Run eviction
locked_tree_clone.evict_tenant_data(max_tree_size);
// Print the process queue
let locked_processed_queue = processed_queue_clone.lock().unwrap();
println!("Processed Queue: {:?}", locked_processed_queue);
}
});
for url in &worker_urls {
url_to_tree.insert(url.clone(), RadixTree::new());
url_to_count.insert(url.clone(), 0);
tree.lock().unwrap().insert(&"".to_string(), url);
}
Router::ApproxTree {
Router::CacheAware {
worker_urls,
url_to_tree: Arc::new(Mutex::new(url_to_tree)),
// TODO: rust ::from_pretrained cannot load from local file, so use ::from_file to load local file
tokenizer: Tokenizer::from_file(tokenizer_path).unwrap(),
url_to_count: Arc::new(Mutex::new(url_to_count)),
tree,
running_queue,
processed_queue,
cache_threshold,
cache_routing_prob,
_eviction_thread: Some(eviction_thread),
}
}
}
......@@ -84,7 +172,7 @@ impl Router {
match self {
Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls }
| Router::ApproxTree { worker_urls, .. } => {
| Router::CacheAware { worker_urls, .. } => {
if worker_urls.is_empty() {
None
} else {
......@@ -100,10 +188,7 @@ impl Router {
req: HttpRequest,
body: Bytes,
) -> HttpResponse {
let mut input_ids: Vec<u32> = Vec::new();
if let Router::ApproxTree { tokenizer, .. } = self {
input_ids = get_token_ids_from_request(&body, tokenizer);
}
let text = get_text_from_request(&body);
let worker_url = match self {
Router::RoundRobin {
......@@ -125,78 +210,73 @@ impl Router {
worker_urls[rand::random::<usize>() % worker_urls.len()].clone()
}
Router::ApproxTree {
Router::CacheAware {
worker_urls,
url_to_tree,
url_to_count,
tree,
running_queue,
processed_queue,
cache_threshold,
cache_routing_prob,
..
} => {
// TODO: pipeline the locks. Release one earlier.
// Even though the tree is thread-safe, we still take a lock so that the whole operation (tree read + queue read + tree write + queue write) is atomic; this handles edge cases such as multiple requests with a long shared prefix arriving at the same time.
let mut max_matched_rate = 0.0;
let mut max_matched_idx = 0;
let mut tree = tree.lock().unwrap();
let mut running_queue = running_queue.lock().unwrap();
let locked_url_to_tree = url_to_tree.lock().unwrap();
// Generate a random float between 0 and 1 for probability check
let sampled_p: f32 = rand::random();
// 1. Find the highest matched worker
for (i, url) in worker_urls.iter().enumerate() {
let tree = locked_url_to_tree.get(url).unwrap();
let matched = tree.prefix_match(&input_ids[..]).len();
let matched_rate = matched as f32 / input_ids.len() as f32;
let selected_url = if sampled_p < *cache_routing_prob {
// Cache-aware routing logic
let (matched_text, matched_worker) = tree.prefix_match(&text);
let matched_rate =
matched_text.chars().count() as f32 / text.chars().count() as f32;
if matched_rate > max_matched_rate {
max_matched_rate = matched_rate;
max_matched_idx = i;
}
}
if matched_rate > *cache_threshold {
matched_worker.to_string()
} else {
let m_map: HashMap<String, usize> = tree
.tenant_char_count
.iter()
.map(|entry| (entry.key().clone(), *entry.value()))
.collect();
// 2. If the rate is higher than the threshold, select the worker. If not, select the worker with the shortest queue
if max_matched_rate > *cache_threshold {
worker_urls[max_matched_idx].clone()
} else {
// pick the shortest queue from url_to_count
let locked_url_to_count = url_to_count.lock().unwrap();
let mut min_count = std::usize::MAX;
let mut min_count_id = 0;
for (i, url) in worker_urls.iter().enumerate() {
let count = locked_url_to_count.get(url).unwrap();
if *count < min_count {
min_count = *count;
min_count_id = i;
}
println!("map: {:?}, mmap: {:?}", tree.get_tenant_char_count(), m_map);
tree.get_smallest_tenant()
}
} else {
// Shortest queue routing logic
running_queue
.iter()
.min_by_key(|(_url, &count)| count)
.map(|(url, _)| url.clone())
.unwrap_or_else(|| worker_urls[0].clone())
};
worker_urls[min_count_id].clone()
}
// Update running queue
let count = running_queue.get_mut(&selected_url).unwrap();
*count += 1;
// Update processed queue
let mut locked_processed_queue = processed_queue.lock().unwrap();
let count = locked_processed_queue.get_mut(&selected_url).unwrap();
*count += 1;
// Update tree with the new request
tree.insert(&text, &selected_url);
selected_url
}
};
if let Router::ApproxTree {
url_to_tree,
url_to_count,
..
} = self
{
// Insert input_ids to the tree
let mut locked_url_to_tree = url_to_tree.lock().unwrap();
let selected_tree = locked_url_to_tree.get_mut(&worker_url).unwrap();
selected_tree.insert(&input_ids[..]);
let mut locked_url_to_count = url_to_count.lock().unwrap();
let count = locked_url_to_count.get_mut(&worker_url).unwrap();
*count += 1;
}
// Check if client requested streaming
let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
.unwrap_or(false);
let res = match client
.post(format!("{}/generate", worker_url))
.post(format!("{}/generate", worker_url.clone()))
.header(
"Content-Type",
req.headers()
......@@ -216,23 +296,53 @@ impl Router {
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
if !is_stream {
// TODO: correct the tree based on the prefix the worker actually cached
if let Router::ApproxTree { url_to_count, .. } = self {
let mut locked_url_to_count = url_to_count.lock().unwrap();
let count = locked_url_to_count.get_mut(&worker_url).unwrap();
*count -= 1;
}
match res.bytes().await {
// For non-streaming requests, get response first
let response = match res.bytes().await {
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
Err(_) => HttpResponse::InternalServerError().finish(),
};
// Then decrement running queue counter if using CacheAware
if let Router::CacheAware { running_queue, .. } = self {
if let Ok(mut queue) = running_queue.lock() {
if let Some(count) = queue.get_mut(&worker_url) {
*count = count.saturating_sub(1);
}
}
}
response
} else if let Router::CacheAware { running_queue, .. } = self {
let running_queue = Arc::clone(running_queue);
let worker_url = worker_url.clone();
HttpResponse::build(status)
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
.streaming(
res.bytes_stream()
.map_err(|_| {
actix_web::error::ErrorInternalServerError("Failed to read stream")
})
.inspect(move |bytes| {
let bytes = bytes.as_ref().unwrap();
if bytes
.as_ref()
.windows(12)
.any(|window| window == b"data: [DONE]")
{
let mut locked_queue = running_queue.lock().unwrap();
let count = locked_queue.get_mut(&worker_url).unwrap();
*count = count.saturating_sub(1);
}
}),
)
} else {
// TODO: correct the tree based on the prefix the worker actually cached; streaming might be trickier to handle
HttpResponse::build(status)
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
.streaming(res.bytes_stream().map_err(|_| {
actix_web::error::ErrorInternalServerError("Failed to read string")
actix_web::error::ErrorInternalServerError("Failed to read stream")
}))
}
}
......
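For readers skimming the Rust above, the cache-aware branch boils down to a short decision procedure. A Python sketch of the same flow (`tree` and `running_queue` mirror the router's state; the names are illustrative, not the crate's API):

```python
import random

def select_worker(tree, running_queue, text, cache_threshold, cache_routing_prob):
    """Sketch of the CacheAware routing decision in router.rs."""
    if random.random() < cache_routing_prob:
        # Cache-aware: route by approximate prefix match.
        matched_text, matched_worker = tree.prefix_match(text)
        matched_rate = len(matched_text) / max(len(text), 1)
        if matched_rate > cache_threshold:
            worker = matched_worker              # likely has this prefix cached
        else:
            worker = tree.get_smallest_tenant()  # most spare cache capacity
    else:
        # Load balancing: worker with the fewest in-flight requests.
        worker = min(running_queue, key=running_queue.get)
    running_queue[worker] += 1  # decremented again when the response finishes
    tree.insert(text, worker)   # grow the approximate tree for this worker
    return worker
```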
......@@ -76,6 +76,7 @@ pub async fn startup(
) -> std::io::Result<()> {
println!("Starting server on {}:{}", host, port);
println!("Worker URLs: {:?}", worker_urls);
println!("Policy Config: {:?}", policy_config);
// Create client once with configuration
let client = reqwest::Client::builder()
......
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use rand::distributions::{Alphanumeric, DistString};
use rand::thread_rng;
use std::cmp::min;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::collections::HashMap;
use std::mem;
use std::sync::Arc;
use std::sync::RwLock;
use std::thread;
use std::time::Duration;
use std::time::{SystemTime, UNIX_EPOCH};
type NodeRef = Arc<Node>;
#[derive(Debug)]
pub struct Node {
pub children: HashMap<u32, Node>, // the key is first id of the child because each child must have unique first id
pub ids: Vec<u32>,
pub count: u32,
struct Node {
children: DashMap<char, NodeRef>,
text: RwLock<String>,
tenant_last_access_time: DashMap<String, u128>,
parent: RwLock<Option<NodeRef>>,
}
#[derive(Debug)]
pub struct RadixTree {
pub root: Node,
pub struct Tree {
root: NodeRef,
// Char count per tenant, maintained incrementally so the smallest tenant can be found cheaply
pub tenant_char_count: DashMap<String, usize>,
}
fn common_prefix_len(a: &[u32], b: &[u32]) -> usize {
let mut i = 0;
while i < a.len() && i < b.len() && a[i] == b[i] {
i += 1;
// For the heap
struct EvictionEntry {
timestamp: u128,
tenant: String,
node: NodeRef,
}
impl Eq for EvictionEntry {}
impl PartialOrd for EvictionEntry {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.timestamp.cmp(&other.timestamp))
}
i
}
impl Default for RadixTree {
fn default() -> Self {
Self::new()
impl Ord for EvictionEntry {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.timestamp.cmp(&other.timestamp)
}
}
impl RadixTree {
impl PartialEq for EvictionEntry {
fn eq(&self, other: &Self) -> bool {
self.timestamp == other.timestamp
}
}
// For char operations
// Note that in Rust, `.len()` and slicing operate at the byte level. That causes issues for UTF-8 text because one character may span multiple bytes.
// https://en.wikipedia.org/wiki/UTF-8
fn shared_prefix_count(a: &str, b: &str) -> usize {
let mut i = 0;
let mut a_iter = a.chars();
let mut b_iter = b.chars();
loop {
match (a_iter.next(), b_iter.next()) {
(Some(a_char), Some(b_char)) if a_char == b_char => {
i += 1;
}
_ => break,
}
}
return i;
}
fn slice_by_chars(s: &str, start: usize, end: usize) -> String {
s.chars().skip(start).take(end - start).collect()
}
impl Tree {
/*
Thread-safe multi-tenant radix tree:
1. Stores data for multiple tenants (the overlap of multiple radix trees)
2. Node-level locks enable concurrent access on nodes
3. Leaf LRU eviction based on tenant access time
*/
pub fn new() -> Self {
RadixTree {
root: Node {
children: HashMap::new(),
ids: Vec::new(),
count: 0,
},
Tree {
root: Arc::new(Node {
children: DashMap::new(),
text: RwLock::new("".to_string()),
tenant_last_access_time: DashMap::new(),
parent: RwLock::new(None),
}),
tenant_char_count: DashMap::new(),
}
}
pub fn insert(&mut self, input_ids: &[u32]) {
let mut curr = &mut self.root;
curr.count += 1;
pub fn insert(&self, text: &str, tenant: &str) {
// Insert text into tree with given tenant
let mut curr = Arc::clone(&self.root);
let mut curr_idx = 0;
let input_ids_len = input_ids.len();
while curr_idx < input_ids_len {
let first_id = &input_ids[curr_idx];
// TODO: changing this get_mut causes error
if curr.children.contains_key(first_id) {
let child = curr.children.get_mut(first_id).unwrap();
let prefix_len = common_prefix_len(&input_ids[curr_idx..], &child.ids);
if prefix_len == child.ids.len() {
// move curr to child
curr = child;
curr.count += 1;
curr_idx += prefix_len;
} else {
// split child
// [child]->... => [child]->[new child]->...
let new_child = Node {
// to avoid clone: replace child.children with default value (empty vector) and return the original value
children: mem::take(&mut child.children),
ids: child.ids[prefix_len..].to_vec(),
count: child.count,
};
child.ids = child.ids[..prefix_len].to_vec();
child.children = HashMap::new();
child.children.insert(new_child.ids[0], new_child);
curr = child;
curr.count += 1;
curr_idx += prefix_len;
let timestamp_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis();
curr.tenant_last_access_time
.insert(tenant.to_string(), timestamp_ms);
self.tenant_char_count
.entry(tenant.to_string())
.or_insert(0);
let mut prev = Arc::clone(&self.root);
let text_count = text.chars().count();
while curr_idx < text_count {
let first_char = text.chars().nth(curr_idx).unwrap();
curr = prev;
// dashmap.entry locks the entry until the operation completes.
// With contains_key + insert there would be a race: e.g. "apple" and "app"
// arrive at the same time, both pass contains_key, and only one of the
// concurrent inserts survives.
match curr.children.entry(first_char) {
Entry::Vacant(entry) => {
/*
no matched
[curr]
becomes
[curr] => [new node]
*/
let curr_text = slice_by_chars(text, curr_idx, text_count);
let curr_text_count = curr_text.chars().count();
let new_node = Arc::new(Node {
children: DashMap::new(),
text: RwLock::new(curr_text),
tenant_last_access_time: DashMap::new(),
parent: RwLock::new(Some(Arc::clone(&curr))),
});
// Increment char count when creating new node with tenant
self.tenant_char_count
.entry(tenant.to_string())
.and_modify(|count| *count += curr_text_count)
.or_insert(curr_text_count);
new_node
.tenant_last_access_time
.insert(tenant.to_string(), timestamp_ms);
entry.insert(Arc::clone(&new_node));
prev = Arc::clone(&new_node);
curr_idx = text_count;
}
Entry::Occupied(mut entry) => {
// matched
let matched_node = entry.get().clone();
let matched_node_text = matched_node.text.read().unwrap().to_owned();
let matched_node_text_count = matched_node_text.chars().count();
let curr_text = slice_by_chars(text, curr_idx, text_count);
let shared_count = shared_prefix_count(&matched_node_text, &curr_text);
if shared_count < matched_node_text_count {
/*
split the matched node
[curr] -> [matched_node] =>
becomes
[curr] -> [new_node] -> [contracted_matched_node]
*/
let matched_text = slice_by_chars(&matched_node_text, 0, shared_count);
let contracted_text = slice_by_chars(
&matched_node_text,
shared_count,
matched_node_text_count,
);
let matched_text_count = matched_text.chars().count();
let new_node = Arc::new(Node {
text: RwLock::new(matched_text),
children: DashMap::new(),
parent: RwLock::new(Some(Arc::clone(&curr))),
tenant_last_access_time: matched_node.tenant_last_access_time.clone(),
});
let first_new_char = contracted_text.chars().nth(0).unwrap();
new_node
.children
.insert(first_new_char, Arc::clone(&matched_node));
entry.insert(Arc::clone(&new_node));
*matched_node.text.write().unwrap() = contracted_text;
*matched_node.parent.write().unwrap() = Some(Arc::clone(&new_node));
prev = Arc::clone(&new_node);
// Increment char count for the tenant in the new split node
if !prev.tenant_last_access_time.contains_key(tenant) {
self.tenant_char_count
.entry(tenant.to_string())
.and_modify(|count| *count += matched_text_count)
.or_insert(matched_text_count);
}
prev.tenant_last_access_time
.insert(tenant.to_string(), timestamp_ms);
curr_idx += shared_count;
} else {
// move to next node
prev = Arc::clone(&matched_node);
// Increment char count when adding tenant to existing node
if !prev.tenant_last_access_time.contains_key(tenant) {
self.tenant_char_count
.entry(tenant.to_string())
.and_modify(|count| *count += matched_node_text_count)
.or_insert(matched_node_text_count);
}
prev.tenant_last_access_time
.insert(tenant.to_string(), timestamp_ms);
curr_idx += shared_count;
}
}
} else {
// create new child
let new_child = Node {
children: HashMap::new(),
ids: input_ids[curr_idx..].to_vec(),
count: 0,
};
let first_id = new_child.ids[0];
curr.children.insert(first_id, new_child);
curr = curr.children.get_mut(&first_id).unwrap();
curr.count += 1;
curr_idx = input_ids_len;
}
}
}
pub fn prefix_match<'a>(&self, input_ids: &'a [u32]) -> &'a [u32] {
let mut curr = &self.root;
pub fn prefix_match(&self, text: &str) -> (String, String) {
let mut curr = Arc::clone(&self.root);
let mut curr_idx = 0;
let input_ids_len = input_ids.len();
while curr_idx < input_ids_len {
match curr.children.get(&input_ids[curr_idx]) {
Some(child) => {
let prefix_len = common_prefix_len(&input_ids[curr_idx..], &child.ids);
let mut prev = Arc::clone(&self.root);
let text_count = text.chars().count();
while curr_idx < text_count {
let first_char = text.chars().nth(curr_idx).unwrap();
let curr_text = slice_by_chars(text, curr_idx, text_count);
curr = prev.clone();
if prefix_len == child.ids.len() {
curr_idx += prefix_len;
curr = child;
match curr.children.entry(first_char) {
Entry::Occupied(entry) => {
let matched_node = entry.get().clone();
let shared_count =
shared_prefix_count(&matched_node.text.read().unwrap(), &curr_text);
let matched_node_text_count = matched_node.text.read().unwrap().chars().count();
if shared_count == matched_node_text_count {
// Full match with current node's text, continue to next node
curr_idx += shared_count;
prev = Arc::clone(&matched_node);
} else {
curr_idx += prefix_len;
// Partial match, stop here
curr_idx += shared_count;
prev = Arc::clone(&matched_node);
break;
}
}
None => {
Entry::Vacant(_) => {
// No match found, stop here
break;
}
}
}
&input_ids[..curr_idx]
curr = prev.clone();
// Select the first tenant (key in the map)
let tenant = curr
.tenant_last_access_time
.iter()
.next()
.map(|kv| kv.key().to_owned())
.unwrap_or("empty".to_string());
// Traverse from the curr node to the root and update the timestamp
let timestamp_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis();
if !tenant.eq("empty") {
let mut current_node = Some(curr);
while let Some(node) = current_node {
node.tenant_last_access_time
.insert(tenant.clone(), timestamp_ms);
current_node = node.parent.read().unwrap().clone();
}
}
let ret_text = slice_by_chars(text, 0, curr_idx);
(ret_text, tenant)
}
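The insert/split/match behavior is easiest to see on a toy model. A minimal Python sketch using a plain per-character trie (`TinyTree` is illustrative; the real tree above is path-compressed and thread-safe):

```python
class TinyTree:
    """Toy per-character version of the multi-tenant tree above."""

    def __init__(self):
        self.children = {}    # char -> TinyTree
        self.tenants = set()  # tenants whose inserts pass through this node

    def insert(self, text, tenant):
        node = self
        for ch in text:
            node = node.children.setdefault(ch, TinyTree())
            node.tenants.add(tenant)

    def prefix_match(self, text):
        node, matched = self, []
        for ch in text:
            if ch not in node.children:
                break
            node = node.children[ch]
            matched.append(ch)
        tenant = next(iter(node.tenants), "empty")
        return "".join(matched), tenant

tree = TinyTree()
tree.insert("apple", "worker1")
tree.insert("application", "worker2")
assert tree.prefix_match("apple") == ("apple", "worker1")  # full match
assert tree.prefix_match("appetizer")[0] == "app"          # shared prefix only
```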
pub fn delete(&mut self, input_ids: &[u32]) {
let mut curr = &mut self.root;
curr.count -= 1;
fn leaf_of(node: &NodeRef) -> Vec<String> {
/*
Return the tenants for which this node is a leaf (present here but on no child)
*/
let mut candidates: HashMap<String, bool> = node
.tenant_last_access_time
.iter()
.map(|entry| (entry.key().clone(), true))
.collect();
let mut curr_idx = 0;
let input_ids_len = input_ids.len();
for child in node.children.iter() {
for tenant in child.value().tenant_last_access_time.iter() {
candidates.insert(tenant.key().clone(), false);
}
}
while curr_idx < input_ids_len {
let first_id = &input_ids[curr_idx];
candidates
.into_iter()
.filter(|(_, is_leaf)| *is_leaf)
.map(|(tenant, _)| tenant)
.collect()
}
if curr.children.contains_key(first_id) {
let child = curr.children.get(first_id).unwrap();
pub fn evict_tenant_data(&self, max_size: usize) {
// Calculate used size and collect leaves
let mut stack = vec![Arc::clone(&self.root)];
let mut used_size_per_tenant: HashMap<String, usize> = HashMap::new();
let mut pq = BinaryHeap::new();
let prefix_len = common_prefix_len(&input_ids[curr_idx..], &child.ids);
while let Some(curr) = stack.pop() {
for tenant in curr.tenant_last_access_time.iter() {
let size = used_size_per_tenant
.entry(tenant.key().clone())
.or_insert(0);
*size += curr.text.read().unwrap().chars().count();
}
if prefix_len == child.ids.len() {
if child.count == 1 {
// If count will become 0, remove the child
let child = curr.children.get_mut(first_id).unwrap();
child.count -= 1;
curr.children.remove(first_id);
break;
} else {
// Otherwise decrement count and continue
let child = curr.children.get_mut(first_id).unwrap();
for child in curr.children.iter() {
stack.push(Arc::clone(child.value()));
}
// Add leaves to priority queue
for tenant in Tree::leaf_of(&curr) {
if let Some(timestamp) = curr.tenant_last_access_time.get(&tenant) {
pq.push(Reverse(EvictionEntry {
timestamp: *timestamp,
tenant: tenant.clone(),
node: Arc::clone(&curr),
}));
}
}
}
println!("Before eviction - Used size per tenant:");
for (tenant, size) in &used_size_per_tenant {
println!("Tenant: {}, Size: {}", tenant, size);
}
child.count -= 1;
curr = child;
curr_idx += prefix_len;
// Process eviction
while let Some(Reverse(entry)) = pq.pop() {
let EvictionEntry { tenant, node, .. } = entry;
if let Some(&used_size) = used_size_per_tenant.get(&tenant) {
if used_size <= max_size {
continue;
}
// Update used size
if let Some(size) = used_size_per_tenant.get_mut(&tenant) {
*size -= node.text.read().unwrap().chars().count();
}
// Decrement when removing tenant from node
if node.tenant_last_access_time.contains_key(&tenant) {
self.tenant_char_count
.entry(tenant.clone())
.and_modify(|count| {
if *count > 0 {
*count -= node.text.read().unwrap().chars().count();
}
});
}
// Remove tenant from node
node.tenant_last_access_time.remove(&tenant);
// Remove empty nodes
if node.children.is_empty() && node.tenant_last_access_time.is_empty() {
if let Some(parent) = node.parent.write().unwrap().as_ref() {
let first_char = node.text.read().unwrap().chars().next().unwrap();
parent.children.remove(&first_char);
}
}
// Add parent to queue if it becomes a leaf
if let Some(parent) = node.parent.read().unwrap().as_ref() {
if Tree::leaf_of(parent).contains(&tenant) {
if let Some(timestamp) = parent.tenant_last_access_time.get(&tenant) {
pq.push(Reverse(EvictionEntry {
timestamp: *timestamp,
tenant: tenant.clone(),
node: Arc::clone(parent),
}));
}
}
} else {
panic!("No match found for {:?}", input_ids);
}
} else {
panic!("No match found for {:?}", input_ids);
}
}
println!("\nAfter eviction - Used size per tenant:");
for (tenant, size) in &used_size_per_tenant {
println!("Tenant: {}, Size: {}", tenant, size);
}
}
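The eviction pass is an LRU sweep over per-tenant leaves. A condensed, self-contained Python sketch of the core loop (`Leaf` and its fields are illustrative; the real code also detaches emptied nodes and re-queues parents that become leaves):

```python
import heapq
from dataclasses import dataclass, field

@dataclass(order=True)
class Leaf:
    last_access_ms: int                    # heap key: oldest evicted first
    tenant: str = field(compare=False)
    char_count: int = field(compare=False)

def evict(leaves, used, max_size):
    """Evict a tenant's least recently used leaves until it fits max_size."""
    heapq.heapify(leaves)
    while leaves:
        leaf = heapq.heappop(leaves)
        if used[leaf.tenant] <= max_size:
            continue  # this tenant is already within budget
        used[leaf.tenant] -= leaf.char_count

used = {"worker1": 10, "worker2": 3}
leaves = [Leaf(100, "worker1", 4), Leaf(200, "worker1", 6), Leaf(150, "worker2", 3)]
evict(leaves, used, max_size=5)
assert used == {"worker1": 0, "worker2": 3}
```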
pub fn get_tenant_char_count(&self) -> HashMap<String, usize> {
self.tenant_char_count
.iter()
.map(|entry| (entry.key().clone(), *entry.value()))
.collect()
}
pub fn get_smallest_tenant(&self) -> String {
// Return a placeholder if there are no tenants
if self.tenant_char_count.is_empty() {
return "empty".to_string();
}
// Find the tenant with minimum char count
let mut min_tenant = None;
let mut min_count = usize::MAX;
for entry in self.tenant_char_count.iter() {
let tenant = entry.key();
let count = *entry.value();
if count < min_count {
min_count = count;
min_tenant = Some(tenant.clone());
}
}
// Return the found tenant or "empty" if somehow none was found
min_tenant.unwrap_or_else(|| "empty".to_string())
}
pub fn get_used_size_per_tenant(&self) -> HashMap<String, usize> {
// perform a DFS to traverse all nodes and calculate the total size used by each tenant
let mut used_size_per_tenant: HashMap<String, usize> = HashMap::new();
let mut stack = vec![Arc::clone(&self.root)];
while let Some(curr) = stack.pop() {
let text_count = curr.text.read().unwrap().chars().count();
for tenant in curr.tenant_last_access_time.iter() {
let size = used_size_per_tenant
.entry(tenant.key().clone())
.or_insert(0);
*size += text_count;
}
for child in curr.children.iter() {
stack.push(Arc::clone(child.value()));
}
}
used_size_per_tenant
}
fn node_to_string(node: &NodeRef, prefix: &str, is_last: bool) -> String {
let mut result = String::new();
// Add prefix and branch character
result.push_str(prefix);
result.push_str(if is_last { "└── " } else { "├── " });
// Add node text
let node_text = node.text.read().unwrap();
result.push_str(&format!("'{}' [", node_text));
// Add tenant information with timestamps
let mut tenant_info = Vec::new();
for entry in node.tenant_last_access_time.iter() {
let tenant_id = entry.key();
let timestamp_ms = entry.value();
// Convert milliseconds to seconds and remaining milliseconds
let seconds = (timestamp_ms / 1000) as u64;
let millis = (timestamp_ms % 1000) as u32;
// Create SystemTime from Unix timestamp
let system_time = UNIX_EPOCH + Duration::from_secs(seconds);
// Format time as HH:MM:SS.mmm
let datetime = system_time.duration_since(UNIX_EPOCH).unwrap();
let hours = (datetime.as_secs() % 86400) / 3600;
let minutes = (datetime.as_secs() % 3600) / 60;
let seconds = datetime.as_secs() % 60;
tenant_info.push(format!(
"{} | {:02}:{:02}:{:02}.{:03}",
tenant_id, hours, minutes, seconds, millis
));
}
result.push_str(&tenant_info.join(", "));
result.push_str("]\n");
// Process children
let children: Vec<_> = node.children.iter().collect();
let child_count = children.len();
for (i, entry) in children.iter().enumerate() {
let is_last_child = i == child_count - 1;
let new_prefix = format!("{}{}", prefix, if is_last { " " } else { "│ " });
result.push_str(&Tree::node_to_string(
entry.value(),
&new_prefix,
is_last_child,
));
}
result
}
// for debug
pub fn pretty_print(&self) {
println!("RadixTree:");
Self::print_node(&self.root, String::from(""));
if self.root.children.is_empty() {
return;
}
let mut result = String::new();
let children: Vec<_> = self.root.children.iter().collect();
let child_count = children.len();
for (i, entry) in children.iter().enumerate() {
let is_last = i == child_count - 1;
result.push_str(&Tree::node_to_string(entry.value(), "", is_last));
}
println!("{result}");
return;
}
}
// Unit tests
#[cfg(test)]
mod tests {
use std::time::Instant;
use rand::Rng;
use super::*;
#[test]
fn test_get_smallest_tenant() {
let tree = Tree::new();
// Test empty tree
assert_eq!(tree.get_smallest_tenant(), "empty");
// Insert data for tenant1 - "ap" + "icot" = 6 chars
tree.insert("ap", "tenant1");
tree.insert("icot", "tenant1");
// Insert data for tenant2 - "cat" = 3 chars
tree.insert("cat", "tenant2");
// Test - tenant2 should be smallest with 3 chars vs 6 chars
assert_eq!(
tree.get_smallest_tenant(),
"tenant2",
"Expected tenant2 to be smallest with 3 characters"
);
// Insert overlapping data for tenant3 and tenant4 to test equal counts
// tenant3: "do" = 2 chars
// tenant4: "hi" = 2 chars
tree.insert("do", "tenant3");
tree.insert("hi", "tenant4");
// Test - should return either tenant3 or tenant4 (both have 2 chars)
let smallest = tree.get_smallest_tenant();
assert!(
smallest == "tenant3" || smallest == "tenant4",
"Expected either tenant3 or tenant4 (both have 2 characters), got {}",
smallest
);
// Add more text to tenant4 to make it larger
tree.insert("hello", "tenant4"); // Now tenant4 has "hi" + "hello" = 6 chars
// Now tenant3 should be smallest (2 chars vs 6 chars for tenant4)
assert_eq!(
tree.get_smallest_tenant(),
"tenant3",
"Expected tenant3 to be smallest with 2 characters"
);
// Test eviction
tree.evict_tenant_data(3); // Evict until every tenant's usage is at most 3 chars
let post_eviction_smallest = tree.get_smallest_tenant();
println!("Smallest tenant after eviction: {}", post_eviction_smallest);
}
#[test]
fn test_tenant_char_count() {
let tree = Tree::new();
// Phase 1: Initial insertions
tree.insert("apple", "tenant1");
tree.insert("apricot", "tenant1");
tree.insert("banana", "tenant1");
tree.insert("amplify", "tenant2");
tree.insert("application", "tenant2");
let computed_sizes = tree.get_used_size_per_tenant();
let maintained_counts: HashMap<String, usize> = tree
.tenant_char_count
.iter()
.map(|entry| (entry.key().clone(), *entry.value()))
.collect();
println!("Phase 1 - Maintained vs Computed counts:");
println!(
"Maintained: {:?}\nComputed: {:?}",
maintained_counts, computed_sizes
);
assert_eq!(
maintained_counts, computed_sizes,
"Phase 1: Initial insertions"
);
// Phase 2: Additional insertions
tree.insert("apartment", "tenant1");
tree.insert("appetite", "tenant2");
tree.insert("ball", "tenant1");
tree.insert("box", "tenant2");
let computed_sizes = tree.get_used_size_per_tenant();
let maintained_counts: HashMap<String, usize> = tree
.tenant_char_count
.iter()
.map(|entry| (entry.key().clone(), *entry.value()))
.collect();
println!("Phase 2 - Maintained vs Computed counts:");
println!(
"Maintained: {:?}\nComputed: {:?}",
maintained_counts, computed_sizes
);
assert_eq!(
maintained_counts, computed_sizes,
"Phase 2: Additional insertions"
);
// Phase 3: Overlapping insertions
tree.insert("zebra", "tenant1");
tree.insert("zebra", "tenant2");
tree.insert("zero", "tenant1");
tree.insert("zero", "tenant2");
let computed_sizes = tree.get_used_size_per_tenant();
let maintained_counts: HashMap<String, usize> = tree
.tenant_char_count
.iter()
.map(|entry| (entry.key().clone(), *entry.value()))
.collect();
println!("Phase 3 - Maintained vs Computed counts:");
println!(
"Maintained: {:?}\nComputed: {:?}",
maintained_counts, computed_sizes
);
assert_eq!(
maintained_counts, computed_sizes,
"Phase 3: Overlapping insertions"
);
// Phase 4: Eviction test
tree.evict_tenant_data(10);
let computed_sizes = tree.get_used_size_per_tenant();
let maintained_counts: HashMap<String, usize> = tree
.tenant_char_count
.iter()
.map(|entry| (entry.key().clone(), *entry.value()))
.collect();
println!("Phase 4 - Maintained vs Computed counts:");
println!(
"Maintained: {:?}\nComputed: {:?}",
maintained_counts, computed_sizes
);
assert_eq!(maintained_counts, computed_sizes, "Phase 4: After eviction");
}
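// The phases above cross-check the incrementally maintained `tenant_char_count`
// against `get_used_size_per_tenant`, which recomputes usage from scratch via a
// tree walk; the two must agree after inserts, overlapping inserts, and eviction alike.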
fn random_string(len: usize) -> String {
Alphanumeric.sample_string(&mut thread_rng(), len)
}
#[test]
fn test_cold_start() {
let tree = Tree::new();
let (matched_text, tenant) = tree.prefix_match("hello");
assert_eq!(matched_text, "");
assert_eq!(tenant, "empty");
}
#[test]
fn test_exact_match_seq() {
let tree = Tree::new();
tree.insert("hello", "tenant1");
tree.pretty_print();
tree.insert("apple", "tenant2");
tree.pretty_print();
tree.insert("banana", "tenant3");
tree.pretty_print();
let (matched_text, tenant) = tree.prefix_match("hello");
assert_eq!(matched_text, "hello");
assert_eq!(tenant, "tenant1");
let (matched_text, tenant) = tree.prefix_match("apple");
assert_eq!(matched_text, "apple");
assert_eq!(tenant, "tenant2");
let (matched_text, tenant) = tree.prefix_match("banana");
assert_eq!(matched_text, "banana");
assert_eq!(tenant, "tenant3");
}
#[test]
fn test_exact_match_concurrent() {
let tree = Arc::new(Tree::new());
// spawn 3 threads for insert
let tree_clone = Arc::clone(&tree);
let texts = vec!["hello", "apple", "banana"];
let tenants = vec!["tenant1", "tenant2", "tenant3"];
let mut handles = vec![];
for i in 0..3 {
let tree_clone = Arc::clone(&tree_clone);
let text = texts[i];
let tenant = tenants[i];
let handle = thread::spawn(move || {
tree_clone.insert(text, tenant);
});
handles.push(handle);
}
// wait
for handle in handles {
handle.join().unwrap();
}
// spawn 3 threads for match
let mut handles = vec![];
let tree_clone = Arc::clone(&tree);
for i in 0..3 {
let tree_clone = Arc::clone(&tree_clone);
let text = texts[i];
let tenant = tenants[i];
let handle = thread::spawn(move || {
let (matched_text, matched_tenant) = tree_clone.prefix_match(text);
assert_eq!(matched_text, text);
assert_eq!(matched_tenant, tenant);
});
handles.push(handle);
}
// wait
for handle in handles {
handle.join().unwrap();
}
}
#[test]
fn test_partial_match_concurrent() {
let tree = Arc::new(Tree::new());
// spawn 3 threads for insert
let tree_clone = Arc::clone(&tree);
let texts = vec!["apple", "apabc", "acbdeds"];
let mut handles = vec![];
for i in 0..3 {
let tree_clone = Arc::clone(&tree_clone);
let text = texts[i];
let tenant = "tenant0";
let handle = thread::spawn(move || {
tree_clone.insert(text, tenant);
});
handles.push(handle);
}
// wait
for handle in handles {
handle.join().unwrap();
}
// spawn 3 threads for match
let mut handles = vec![];
let tree_clone = Arc::clone(&tree);
for i in 0..3 {
let tree_clone = Arc::clone(&tree_clone);
let text = texts[i];
let tenant = "tenant0";
let handle = thread::spawn(move || {
let (matched_text, matched_tenant) = tree_clone.prefix_match(text);
assert_eq!(matched_text, text);
assert_eq!(matched_tenant, tenant);
});
handles.push(handle);
}
// wait
for handle in handles {
handle.join().unwrap();
}
}
#[test]
fn test_group_prefix_insert_match_concurrent() {
let prefix = vec![
"Clock strikes midnight, I'm still wide awake",
"Got dreams bigger than these city lights",
"Time waits for no one, gotta make my move",
"Started from the bottom, that's no metaphor",
];
let suffix = vec![
"Got too much to prove, ain't got time to lose",
"History in the making, yeah, you can't erase this",
];
let tree = Arc::new(Tree::new());
let mut handles = vec![];
for i in 0..prefix.len() {
for j in 0..suffix.len() {
let tree_clone = Arc::clone(&tree);
let text = format!("{} {}", prefix[i], suffix[j]);
let tenant = format!("tenant{}", i);
let handle = thread::spawn(move || {
tree_clone.insert(&text, &tenant);
});
handles.push(handle);
}
}
// wait
for handle in handles {
handle.join().unwrap();
}
tree.pretty_print();
// check matching using multi threads
let mut handles = vec![];
for i in 0..prefix.len() {
let tree_clone = Arc::clone(&tree);
let text = prefix[i];
let handle = thread::spawn(move || {
let (matched_text, matched_tenant) = tree_clone.prefix_match(text);
let tenant = format!("tenant{}", i);
assert_eq!(matched_text, text);
assert_eq!(matched_tenant, tenant);
});
handles.push(handle);
}
// wait
for handle in handles {
handle.join().unwrap();
}
}
#[test]
fn test_mixed_concurrent_insert_match() {
// This test only verifies that mixed concurrent inserts and matches do not deadlock; it makes no correctness assertions
let prefix = vec![
"Clock strikes midnight, I'm still wide awake",
"Got dreams bigger than these city lights",
"Time waits for no one, gotta make my move",
"Started from the bottom, that's no metaphor",
];
let suffix = vec![
"Got too much to prove, ain't got time to lose",
"History in the making, yeah, you can't erase this",
];
let tree = Arc::new(Tree::new());
let mut handles = vec![];
for i in 0..prefix.len() {
for j in 0..suffix.len() {
let tree_clone = Arc::clone(&tree);
let text = format!("{} {}", prefix[i], suffix[j]);
let tenant = format!("tenant{}", i);
let handle = thread::spawn(move || {
tree_clone.insert(&text, &tenant);
});
handles.push(handle);
}
}
// check matching using multi threads
for i in 0..prefix.len() {
let tree_clone = Arc::clone(&tree);
let text = prefix[i];
let handle = thread::spawn(move || {
// Result intentionally ignored; this thread only exercises the read path
let (_matched_text, _matched_tenant) = tree_clone.prefix_match(text);
});
handles.push(handle);
}
// wait
for handle in handles {
handle.join().unwrap();
}
}
#[test]
fn test_utf8_split_seq() {
// Strings must be indexed and split on UTF-8 character boundaries, not bytes;
// the implementation uses .chars() to iterate over whole code points
let tree = Arc::new(Tree::new());
let test_pairs = vec![
("你好嗎", "tenant1"),
("你好喔", "tenant2"),
("你心情好嗎", "tenant3"),
];
// Insert sequentially
for i in 0..test_pairs.len() {
let text = test_pairs[i].0;
let tenant = test_pairs[i].1;
tree.insert(text, tenant);
}
tree.pretty_print();
// Test sequentially
for i in 0..test_pairs.len() {
let (matched_text, matched_tenant) = tree.prefix_match(test_pairs[i].0);
assert_eq!(matched_text, test_pairs[i].0);
assert_eq!(matched_tenant, test_pairs[i].1);
}
}
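// Background for the test above: &str is UTF-8-encoded bytes, and slicing at a
// non-character boundary panics. Illustrative sketch:
//
//     let s = "你好";            // 6 bytes, 2 chars
//     // let bad = &s[0..1];     // would panic: byte 1 is not a char boundary
//     let first = s.chars().next().unwrap(); // '你'
//
// Hence the tree splits on .chars() rather than raw byte offsets.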
#[test]
fn test_utf8_split_concurrent() {
let tree = Arc::new(Tree::new());
let test_pairs = vec![
("你好嗎", "tenant1"),
("你好喔", "tenant2"),
("你心情好嗎", "tenant3"),
];
// Create multiple threads for insertion
let mut handles = vec![];
for i in 0..test_pairs.len() {
let tree_clone = Arc::clone(&tree);
let text = test_pairs[i].0.to_string();
let tenant = test_pairs[i].1.to_string();
let handle = thread::spawn(move || {
tree_clone.insert(&text, &tenant);
});
handles.push(handle);
}
// Wait for all insertions to complete
for handle in handles {
handle.join().unwrap();
}
tree.pretty_print();
// Create multiple threads for matching
let mut handles = vec![];
for i in 0..test_pairs.len() {
let tree_clone = Arc::clone(&tree);
let text = test_pairs[i].0.to_string();
let tenant = test_pairs[i].1.to_string();
let handle = thread::spawn(move || {
let (matched_text, matched_tenant) = tree_clone.prefix_match(&text);
assert_eq!(matched_text, text);
assert_eq!(matched_tenant, tenant);
});
handles.push(handle);
}
// Wait for all matches to complete
for handle in handles {
handle.join().unwrap();
}
}
#[test]
fn test_simple_eviction() {
let tree = Tree::new();
let max_size = 5;
// Insert strings for both tenants
tree.insert("hello", "tenant1"); // size 5
tree.insert("hello", "tenant2"); // size 5
thread::sleep(Duration::from_millis(10));
tree.insert("world", "tenant2"); // size 5, total for tenant2 = 10
tree.pretty_print();
// Verify initial sizes
let sizes_before = tree.get_used_size_per_tenant();
assert_eq!(sizes_before.get("tenant1").unwrap(), &5); // "hello" = 5
assert_eq!(sizes_before.get("tenant2").unwrap(), &10); // "hello" + "world" = 10
// Evict - should remove "hello" from tenant2 as it's the oldest
tree.evict_tenant_data(max_size);
tree.pretty_print();
// Verify sizes after eviction
let sizes_after = tree.get_used_size_per_tenant();
assert_eq!(sizes_after.get("tenant1").unwrap(), &5); // Should be unchanged
assert_eq!(sizes_after.get("tenant2").unwrap(), &5); // Only "world" remains
// Verify "world" remains for tenant2
let (matched, tenant) = tree.prefix_match("world");
assert_eq!(matched, "world");
assert_eq!(tenant, "tenant2");
}
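// Note on the test above: eviction enforces a per-tenant budget in LRU order.
// tenant1 already sits at the 5-char limit, so nothing is evicted for it;
// tenant2 is at 10 chars, so its oldest leaf ("hello", inserted before the
// 10ms sleep) is removed first, leaving only "world".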
#[test]
fn test_advanced_eviction() {
let tree = Tree::new();
// Set limits for each tenant
let max_size: usize = 100;
// Define prefixes
let prefixes = vec!["aqwefcisdf", "iajsdfkmade", "kjnzxcvewqe", "iejksduqasd"];
// Insert strings with shared prefixes
for i in 0..100 {
for (j, prefix) in prefixes.iter().enumerate() {
let random_suffix = random_string(10);
let text = format!("{}{}", prefix, random_suffix);
let tenant = format!("tenant{}", j + 1);
tree.insert(&text, &tenant);
}
}
// Perform eviction
tree.evict_tenant_data(max_size);
// Check sizes after eviction
let sizes_after = tree.get_used_size_per_tenant();
// Verify all tenants are under their size limits
for (tenant, &size) in sizes_after.iter() {
assert!(
size <= max_size,
"Tenant {} exceeds size limit. Current size: {}, Limit: {}",
tenant,
size,
max_size
);
}
}
#[test]
fn test_concurrent_operations_with_eviction() {
// Ensure eviction runs safely alongside concurrent insert and match operations over a fixed period
let tree = Arc::new(Tree::new());
let mut handles = vec![];
let test_duration = Duration::from_secs(10);
let start_time = Instant::now();
let max_size = 100; // Single max size for all tenants
// Spawn eviction thread
{
let tree = Arc::clone(&tree);
let handle = thread::spawn(move || {
while start_time.elapsed() < test_duration {
// Run eviction
tree.evict_tenant_data(max_size);
// Sleep for 5 seconds
thread::sleep(Duration::from_secs(5));
}
});
handles.push(handle);
}
// Spawn 4 worker threads
for thread_id in 0..4 {
let tree = Arc::clone(&tree);
let handle = thread::spawn(move || {
let mut rng = rand::thread_rng();
let tenant = format!("tenant{}", thread_id + 1);
let prefix = format!("prefix{}", thread_id);
while start_time.elapsed() < test_duration {
// Random decision: match or insert (70% match, 30% insert)
if rng.gen_bool(0.7) {
// Perform match operation
let random_len = rng.gen_range(3..10);
let search_str = format!("{}{}", prefix, random_string(random_len));
let (_matched, _) = tree.prefix_match(&search_str);
} else {
// Perform insert operation
let random_len = rng.gen_range(5..15);
let insert_str = format!("{}{}", prefix, random_string(random_len));
tree.insert(&insert_str, &tenant);
// println!("Thread {} inserted: {}", thread_id, insert_str);
}
// Small random sleep to vary timing
thread::sleep(Duration::from_millis(rng.gen_range(10..100)));
}
});
handles.push(handle);
}
// Wait for all threads to complete
for handle in handles {
handle.join().unwrap();
}
// final eviction
tree.evict_tenant_data(max_size);
// Final size check
let final_sizes = tree.get_used_size_per_tenant();
println!("Final sizes after test completion: {:?}", final_sizes);
// Verify all tenants are under limit
for (_, &size) in final_sizes.iter() {
assert!(
size <= max_size,
"Tenant exceeds size limit. Final size: {}, Limit: {}",
size,
max_size
);
}
}
#[test]
fn test_leaf_of() {
let tree = Tree::new();
// Single node
tree.insert("hello", "tenant1");
let leaves = Tree::leaf_of(&tree.root.children.get(&'h').unwrap());
assert_eq!(leaves, vec!["tenant1"]);
// Node with multiple tenants
tree.insert("hello", "tenant2");
let leaves = Tree::leaf_of(&tree.root.children.get(&'h').unwrap());
assert_eq!(leaves.len(), 2);
assert!(leaves.contains(&"tenant1".to_string()));
assert!(leaves.contains(&"tenant2".to_string()));
// Non-leaf node
tree.insert("hi", "tenant1");
let leaves = Tree::leaf_of(&tree.root.children.get(&'h').unwrap());
assert!(leaves.is_empty());
}
#[test]
fn test_get_used_size_per_tenant() {
let tree = Tree::new();
// Single tenant
tree.insert("hello", "tenant1");
tree.insert("world", "tenant1");
let sizes = tree.get_used_size_per_tenant();
tree.pretty_print();
println!("{:?}", sizes);
assert_eq!(sizes.get("tenant1").unwrap(), &10); // "hello" + "world"
// Multiple tenants sharing nodes
tree.insert("hello", "tenant2");
tree.insert("help", "tenant2");
let sizes = tree.get_used_size_per_tenant();
tree.pretty_print();
println!("{:?}", sizes);
assert_eq!(sizes.get("tenant1").unwrap(), &10);
assert_eq!(sizes.get("tenant2").unwrap(), &6); // "hello" + "p"
// UTF-8 characters
tree.insert("你好", "tenant3");
let sizes = tree.get_used_size_per_tenant();
tree.pretty_print();
println!("{:?}", sizes);
assert_eq!(sizes.get("tenant3").unwrap(), &2); // 2 Chinese characters
tree.pretty_print();
}
}
use sglang_router_rs::tree::RadixTree;
#[test]
fn test_new_tree() {
let tree = RadixTree::new();
assert_eq!(tree.root.count, 0);
assert!(tree.root.children.is_empty());
assert!(tree.root.ids.is_empty());
}
#[test]
fn test_single_insertion() {
let mut tree = RadixTree::new();
tree.insert(&[1, 2, 3]);
assert_eq!(tree.root.count, 1);
assert_eq!(tree.root.children.len(), 1);
assert_eq!(tree.root.children[&1].ids, vec![1, 2, 3]);
assert_eq!(tree.root.children[&1].count, 1);
}
#[test]
fn test_multiple_insertions_no_split() {
let mut tree = RadixTree::new();
tree.insert(&[1, 2, 3]);
tree.insert(&[4, 5, 6]);
assert_eq!(tree.root.count, 2);
assert_eq!(tree.root.children.len(), 2);
assert_eq!(tree.root.children[&1].ids, vec![1, 2, 3]);
assert_eq!(tree.root.children[&4].ids, vec![4, 5, 6]);
}
#[test]
fn test_insertion_with_split() {
let mut tree = RadixTree::new();
tree.insert(&[1, 2, 3, 4]);
tree.insert(&[1, 2, 5, 6]);
assert_eq!(tree.root.count, 2);
assert_eq!(tree.root.children.len(), 1);
assert_eq!(tree.root.children[&1].ids, vec![1, 2]);
assert_eq!(tree.root.children[&1].children.len(), 2);
assert_eq!(tree.root.children[&1].children[&3].ids, vec![3, 4]);
assert_eq!(tree.root.children[&1].children[&5].ids, vec![5, 6]);
}
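// Resulting shape (per the assertions above): inserting [1, 2, 5, 6] splits the
// existing [1, 2, 3, 4] node at the shared prefix [1, 2]:
//
//     root
//     └── [1, 2]
//         ├── [3, 4]
//         └── [5, 6]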
#[test]
fn test_prefix_match_exact() {
let mut tree = RadixTree::new();
tree.insert(&[1, 2, 3, 4]);
assert_eq!(tree.prefix_match(&[1, 2, 3, 4]), &[1, 2, 3, 4]);
}
#[test]
fn test_prefix_match_partial() {
let mut tree = RadixTree::new();
tree.insert(&[1, 2, 3, 4]);
assert_eq!(tree.prefix_match(&[1, 2, 3, 5]), &[1, 2, 3]);
assert_eq!(tree.prefix_match(&[1, 2, 5]), &[1, 2]);
assert_eq!(tree.prefix_match(&[1, 5]), &[1]);
}
#[test]
fn test_prefix_match_no_match() {
let mut tree = RadixTree::new();
tree.insert(&[1, 2, 3, 4]);
let empty_slices: &[u32] = &[];
assert_eq!(tree.prefix_match(&[5, 6, 7]), empty_slices);
}
#[test]
fn test_delete_leaf() {
let mut tree = RadixTree::new();
tree.insert(&[1, 2, 3]);
tree.delete(&[1, 2, 3]);
assert_eq!(tree.root.count, 0);
assert_eq!(tree.root.children.len(), 0);
}
#[test]
fn test_delete_with_siblings() {
let mut tree = RadixTree::new();
tree.insert(&[1, 2, 3]);
tree.insert(&[1, 2, 4]);
tree.delete(&[1, 2, 3]);
assert_eq!(tree.root.count, 1);
assert_eq!(tree.root.children[&1].children[&4].ids, vec![4]);
}
#[test]
fn test_multiple_operations() {
let mut tree = RadixTree::new();
// Insert several paths
tree.insert(&[1, 2, 3]);
tree.insert(&[1, 2, 4]);
tree.insert(&[1, 5, 6]);
// Verify structure
assert_eq!(tree.root.count, 3);
assert_eq!(tree.prefix_match(&[1, 2, 3]), &[1, 2, 3]);
assert_eq!(tree.prefix_match(&[1, 2, 4]), &[1, 2, 4]);
assert_eq!(tree.prefix_match(&[1, 5, 6]), &[1, 5, 6]);
// Delete and verify
tree.delete(&[1, 2, 3]);
assert_eq!(tree.root.count, 2);
assert_eq!(tree.prefix_match(&[1, 2, 3]), &[1, 2]); // Now only matches prefix
}
#[test]
#[should_panic(expected = "No match found")]
fn test_delete_nonexistent() {
let mut tree = RadixTree::new();
tree.insert(&[1, 2, 3]);
tree.delete(&[4, 5, 6]); // Should panic
}
#[test]
fn test_empty_input() {
let mut tree = RadixTree::new();
let empty_slice: &[u32] = &[];
tree.insert(empty_slice);
assert_eq!(tree.prefix_match(empty_slice), empty_slice);
tree.delete(empty_slice); // Should not panic
}