Unverified Commit cbedd1db authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

[router] cache-aware load-balancing router v1 (#2114)

parent ad47749b
import itertools import itertools
import json import json
import os
import random import random
import string import string
import threading import threading
import time import time
from argparse import ArgumentParser from argparse import ArgumentParser
from pathlib import Path
from typing import Union
from tqdm import tqdm
import sglang as sgl import sglang as sgl
from sglang.srt.hf_transformers_utils import get_tokenize from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.test.test_utils import ( from sglang.test.test_utils import (
add_common_sglang_args_and_parse, add_common_sglang_args_and_parse,
select_sglang_backend, select_sglang_backend,
) )
from sglang.utils import dump_state_text from sglang.utils import dump_state_text
random.seed(42)
def gen_prompt(tokenizer, token_num): def gen_prompt(tokenizer, token_num):
all_available_tokens = list(tokenizer.get_vocab().values()) all_available_tokens = list(tokenizer.get_vocab().values())
...@@ -24,12 +27,34 @@ def gen_prompt(tokenizer, token_num): ...@@ -24,12 +27,34 @@ def gen_prompt(tokenizer, token_num):
return ret return ret
def get_cache_path(args):
    """Build the cache file path (under ~/.cache/sglang) for this argument set.

    The filename encodes every argument that influences generation, so any
    change in those arguments maps to a distinct cache entry.
    """
    # Slashes in HF-style tokenizer names would create subdirectories; flatten them.
    tokenizer_tag = args.tokenizer.replace("/", "_")
    filename = (
        f"qa_{args.num_qa}_{args.turns}_{args.system_prompt_len}"
        f"_{args.len_q}_{args.len_a}_{tokenizer_tag}.json"
    )
    return Path.home() / ".cache" / "sglang" / filename
def gen_arguments(args, tokenizer): def gen_arguments(args, tokenizer):
multi_qas = [ cache_path = get_cache_path(args)
{"system_prompt": gen_prompt(tokenizer, args.system_prompt_len), "qas": []}
for _ in range(args.num_qa) # Try to load from cache first
] if cache_path.exists():
for i in range(args.num_qa): print(f"Loading cached arguments from {cache_path}")
with open(cache_path, "r") as f:
return json.load(f)
print("Generating new arguments...")
# First progress bar for system prompts
multi_qas = []
for _ in tqdm(range(args.num_qa), desc="Generating system prompts"):
multi_qas.append(
{"system_prompt": gen_prompt(tokenizer, args.system_prompt_len), "qas": []}
)
# Nested progress bars for QA pairs
for i in tqdm(range(args.num_qa), desc="Generating QA pairs"):
qas = multi_qas[i]["qas"] qas = multi_qas[i]["qas"]
for j in range(args.turns): for j in range(args.turns):
qas.append( qas.append(
...@@ -38,6 +63,13 @@ def gen_arguments(args, tokenizer): ...@@ -38,6 +63,13 @@ def gen_arguments(args, tokenizer):
"new_tokens": args.len_a, "new_tokens": args.len_a,
} }
) )
# Save to cache
cache_path.parent.mkdir(parents=True, exist_ok=True)
with open(cache_path, "w") as f:
json.dump(multi_qas, f)
print(f"Cached arguments saved to {cache_path}")
return multi_qas return multi_qas
...@@ -45,7 +77,7 @@ def gen_arguments(args, tokenizer): ...@@ -45,7 +77,7 @@ def gen_arguments(args, tokenizer):
def multi_turns(s, system_prompt, qas): def multi_turns(s, system_prompt, qas):
s += system_prompt s += system_prompt
for qa in qas: for i, qa in enumerate(qas):
s += qa["prompt"] s += qa["prompt"]
s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True) s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True)
...@@ -62,7 +94,7 @@ def main(args): ...@@ -62,7 +94,7 @@ def main(args):
multi_qas, multi_qas,
temperature=0, temperature=0,
backend=backend, backend=backend,
num_threads=args.parallel, num_threads="auto",
progress_bar=True, progress_bar=True,
) )
latency = time.time() - tic latency = time.time() - tic
...@@ -75,7 +107,6 @@ def main(args): ...@@ -75,7 +107,6 @@ def main(args):
value = { value = {
"task": "multi_turn_system_prompt_chat", "task": "multi_turn_system_prompt_chat",
"backend": args.backend, "backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3), "latency": round(latency, 3),
"num_requests": args.num_qa, "num_requests": args.num_qa,
"num_turns": args.turns, "num_turns": args.turns,
......
...@@ -727,9 +727,9 @@ def sample_generated_shared_prefix_requests( ...@@ -727,9 +727,9 @@ def sample_generated_shared_prefix_requests(
total_input_tokens = 0 total_input_tokens = 0
total_output_tokens = 0 total_output_tokens = 0
for group_idx in range(num_groups): for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
system_prompt = system_prompts[group_idx] system_prompt = system_prompts[group_idx]
for prompt_idx in range(prompts_per_group): for prompt_idx in tqdm(range(prompts_per_group), desc="Generating questions"):
question = questions[group_idx * prompts_per_group + prompt_idx] question = questions[group_idx * prompts_per_group + prompt_idx]
full_prompt = f"{system_prompt}\n\n{question}" full_prompt = f"{system_prompt}\n\n{question}"
prompt_len = len(tokenizer.encode(full_prompt)) prompt_len = len(tokenizer.encode(full_prompt))
......
...@@ -48,9 +48,13 @@ def run_eval(args): ...@@ -48,9 +48,13 @@ def run_eval(args):
# Select backend # Select backend
set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}")) set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}"))
# Read data if args.data_path is None:
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" # Read data
filename = download_and_cache_file(url) url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
filename = download_and_cache_file(url)
else:
filename = args.data_path
lines = list(read_jsonl(filename)) lines = list(read_jsonl(filename))
# Construct prompts # Construct prompts
......
...@@ -591,6 +591,20 @@ dependencies = [ ...@@ -591,6 +591,20 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "dashmap"
version = "6.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf"
dependencies = [
"cfg-if",
"crossbeam-utils",
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]] [[package]]
name = "deranged" name = "deranged"
version = "0.3.11" version = "0.3.11"
...@@ -904,6 +918,12 @@ dependencies = [ ...@@ -904,6 +918,12 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]] [[package]]
name = "hashbrown" name = "hashbrown"
version = "0.15.1" version = "0.15.1"
...@@ -1226,7 +1246,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -1226,7 +1246,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da"
dependencies = [ dependencies = [
"equivalent", "equivalent",
"hashbrown", "hashbrown 0.15.1",
] ]
[[package]] [[package]]
...@@ -2097,7 +2117,9 @@ dependencies = [ ...@@ -2097,7 +2117,9 @@ dependencies = [
"actix-web", "actix-web",
"bytes", "bytes",
"clap", "clap",
"dashmap",
"futures-util", "futures-util",
"http 1.1.0",
"pyo3", "pyo3",
"rand", "rand",
"reqwest", "reqwest",
......
...@@ -24,6 +24,8 @@ futures-util = "0.3" ...@@ -24,6 +24,8 @@ futures-util = "0.3"
serde_json = "1.0" serde_json = "1.0"
pyo3 = { version = "0.22.5", features = ["extension-module"] } pyo3 = { version = "0.22.5", features = ["extension-module"] }
tokenizers = { version = "0.20.3", features = ["http"] } tokenizers = { version = "0.20.3", features = ["http"] }
dashmap = "6.1.0"
http = "1.1.0"
[profile.release] [profile.release]
lto = "thin" lto = "thin"
......
...@@ -46,6 +46,9 @@ pip install <path-to-wheel> ...@@ -46,6 +46,9 @@ pip install <path-to-wheel>
#### Option B: Development Mode #### Option B: Development Mode
For development purposes, you can install the package in editable mode: For development purposes, you can install the package in editable mode:
Warning: Using editable python binding can suffer from performance degradation!! Please build a fresh wheel for every update if you want to test performance.
```bash ```bash
pip install -e . pip install -e .
``` ```
......
# Minimal usage example: route requests across two local workers with the
# router's default load-balancing policy.
# NOTE(review): PolicyType is imported but unused in this minimal example.
from sglang_router import PolicyType, Router
router = Router(
    worker_urls=[
        "http://localhost:30000",
        "http://localhost:30001",
    ]
)
# Starts the router server (blocking until shut down).
router.start()
import argparse
import os
import signal
import subprocess
import sys
import time
from typing import Dict, List
import requests
from sglang_router import PolicyType, Router
# Global processes list for cleanup
_processes: List[subprocess.Popen] = []
def cleanup_processes(signum=None, frame=None):
    """Signal handler: kill every launched worker's process group and exit.

    Args:
        signum: Signal number (supplied by the signal machinery; unused).
        frame: Current stack frame (supplied by the signal machinery; unused).
    """
    print("\nCleaning up processes...")
    for process in _processes:
        try:
            # Kill the entire process group so children spawned by the worker
            # are terminated along with it.
            pgid = os.getpgid(process.pid)
            os.killpg(pgid, signal.SIGKILL)
            process.wait()
        except OSError:
            # Narrowed from a bare `except:` (which also swallowed
            # KeyboardInterrupt/SystemExit). ProcessLookupError etc. mean the
            # process already exited; nothing to clean up.
            pass
    sys.exit(1)
# Register signal handlers
signal.signal(signal.SIGINT, cleanup_processes)
signal.signal(signal.SIGTERM, cleanup_processes)
def parse_args():
    """Build the CLI for the router launch script and parse sys.argv."""
    arg_parser = argparse.ArgumentParser(description="Launch SGLang Router Server")
    arg_parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help="Host address to bind the server",
    )
    arg_parser.add_argument(
        "--port",
        type=int,
        default=30000,
        help="Base port number for workers",
    )
    arg_parser.add_argument(
        "--dp",
        type=int,
        default=2,
        help="Number of worker processes (degree of parallelism)",
    )
    arg_parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to the model",
    )
    arg_parser.add_argument(
        "--local-tokenizer-path",
        type=str,
        required=True,
        help="Path to the local tokenizer",
    )
    return arg_parser.parse_args()
def launch_workers(args) -> tuple[List[subprocess.Popen], List[str]]:
    """Launch all worker processes concurrently using subprocess.

    Each worker is pinned to its own GPU and listens on a consecutive port
    starting at ``args.port``.

    Returns:
        A tuple of (worker processes, worker base URLs).
    """
    processes = []
    worker_urls = []

    # Launch each worker process
    for i in range(args.dp):
        port = args.port + i
        url = f"http://{args.host}:{port}"
        worker_urls.append(url)

        # TODO: replace this with launch_server, and move this file to sglang/ because it depends on sglang
        # Pin the worker to GPU `i` via the child's environment instead of a
        # shell `export`, so no shell is involved and model paths containing
        # spaces or shell metacharacters are passed through safely.
        command = [
            sys.executable,
            "-m",
            "sglang.launch_server",
            "--model-path",
            args.model_path,
            "--host",
            args.host,
            "--port",
            str(port),
        ]
        print(" ".join(command))
        env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(i))
        process = subprocess.Popen(command, env=env)
        processes.append(process)
        _processes.append(process)  # Add to global list for cleanup

    return processes, worker_urls
def wait_for_healthy_workers(worker_urls: List[str], timeout: int = 300) -> bool:
    """Block until every worker answers 200 on /health, or until timeout.

    Args:
        worker_urls: Base URLs of the workers to poll.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True if all workers became healthy within the timeout, else False.
    """
    start_time = time.time()
    healthy_workers: Dict[str, bool] = {url: False for url in worker_urls}

    while time.time() - start_time < timeout:
        print("checking healthiness...")
        all_healthy = True
        for url in worker_urls:
            if not healthy_workers[url]:  # Only check workers that aren't healthy yet
                try:
                    # Bound each probe: without a timeout a hung worker could
                    # stall this request forever, defeating the outer timeout.
                    response = requests.get(f"{url}/health", timeout=5)
                    if response.status_code == 200:
                        print(f"Worker at {url} is healthy")
                        healthy_workers[url] = True
                    else:
                        all_healthy = False
                except requests.RequestException:
                    all_healthy = False

        if all_healthy:
            print("All workers are healthy!")
            return True

        time.sleep(5)

    # If we get here, we've timed out
    unhealthy_workers = [url for url, healthy in healthy_workers.items() if not healthy]
    print(f"Timeout waiting for workers: {unhealthy_workers}")
    return False
def main():
    """Main function to launch the router and workers."""
    args = parse_args()
    processes = None

    try:
        # Launch all workers concurrently
        processes, worker_urls = launch_workers(args)

        # Block until all workers are healthy
        if not wait_for_healthy_workers(worker_urls):
            raise RuntimeError("Failed to start all workers")

        # Initialize and start the router.
        # FIX: the policy enum was renamed ApproxTree -> CacheAware and the
        # Router no longer accepts tokenizer_path (the tree stores raw text),
        # so the previous call would fail at runtime.
        router = Router(
            worker_urls=worker_urls,
            policy=PolicyType.CacheAware,
        )
        print("Starting router...")
        router.start()

        # Keep the main process running so the workers stay attached
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            print("\nShutting down...")

    except Exception as e:
        print(f"Error: {e}")
    finally:
        # Cleanup: Kill all worker processes
        if processes:
            for process in processes:
                process.kill()


if __name__ == "__main__":
    main()
import argparse
import dataclasses
import sys
from typing import List, Optional
from sglang_router import Router
from sglang_router_rs import PolicyType
@dataclasses.dataclass
class RouterArgs:
    """CLI-facing configuration for the SGLang router.

    Mirrors the keyword arguments of ``Router`` and knows how to register
    itself on an ``argparse`` parser and rebuild itself from the parsed result.
    """

    # Worker configuration
    worker_urls: List[str]
    host: str = "127.0.0.1"
    port: int = 30000
    # Routing policy
    policy: str = "cache_aware"
    cache_threshold: float = 0.5
    cache_routing_prob: float = 1.0
    eviction_interval: int = 60
    max_tree_size: int = 2**24
    @staticmethod
    def add_cli_args(
        parser: argparse.ArgumentParser,
        use_router_prefix: bool = False,
        exclude_host_port: bool = False,
    ):
        """
        Add router-specific arguments to an argument parser.

        Args:
            parser: The argument parser to add arguments to
            use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts
            exclude_host_port: If True, don't add host and port arguments (used when inheriting from server)
        """
        # Only the policy-related flags get the prefix; --worker-urls stays
        # unprefixed in both modes.
        prefix = "router-" if use_router_prefix else ""
        # Worker configuration
        if not exclude_host_port:
            parser.add_argument(
                "--host",
                type=str,
                default=RouterArgs.host,
                help="Host address to bind the router server",
            )
            parser.add_argument(
                "--port",
                type=int,
                default=RouterArgs.port,
                help="Port number to bind the router server",
            )
        # NOTE(review): --worker-urls has no default and is not required, so
        # worker_urls may be None when omitted — confirm callers handle this.
        parser.add_argument(
            "--worker-urls",
            type=str,
            nargs="+",
            help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)",
        )
        # Routing policy configuration
        parser.add_argument(
            f"--{prefix}policy",
            type=str,
            default=RouterArgs.policy,
            choices=["random", "round_robin", "cache_aware"],
            help="Load balancing policy to use",
        )
        parser.add_argument(
            f"--{prefix}cache-threshold",
            type=float,
            default=RouterArgs.cache_threshold,
            help="Cache threshold (0.0-1.0) for cache-aware routing",
        )
        parser.add_argument(
            f"--{prefix}cache-routing-prob",
            type=float,
            default=RouterArgs.cache_routing_prob,
            help="Probability of using cache-aware routing (0.0-1.0)",
        )
        parser.add_argument(
            f"--{prefix}eviction-interval",
            type=int,
            default=RouterArgs.eviction_interval,
            help="Interval in seconds between cache eviction operations",
        )
        parser.add_argument(
            f"--{prefix}max-tree-size",
            type=int,
            default=RouterArgs.max_tree_size,
            help="Maximum size of the approximation tree for cache-aware routing",
        )
    @classmethod
    def from_cli_args(
        cls, args: argparse.Namespace, use_router_prefix: bool = False
    ) -> "RouterArgs":
        """
        Create RouterArgs instance from parsed command line arguments.

        Args:
            args: Parsed command line arguments
            use_router_prefix: If True, look for arguments with 'router-' prefix
        """
        # argparse converts dashes to underscores, hence "router_" here versus
        # the "router-" used when registering the flags above.
        prefix = "router_" if use_router_prefix else ""
        return cls(
            worker_urls=args.worker_urls,
            host=args.host,
            port=args.port,
            policy=getattr(args, f"{prefix}policy"),
            cache_threshold=getattr(args, f"{prefix}cache_threshold"),
            cache_routing_prob=getattr(args, f"{prefix}cache_routing_prob"),
            eviction_interval=getattr(args, f"{prefix}eviction_interval"),
            max_tree_size=getattr(args, f"{prefix}max_tree_size"),
        )
def policy_from_str(policy_str: str) -> PolicyType:
    """Map a policy name ("random" | "round_robin" | "cache_aware") to its enum.

    Raises:
        KeyError: If the name is not one of the supported policies.
    """
    if policy_str == "random":
        return PolicyType.Random
    if policy_str == "round_robin":
        return PolicyType.RoundRobin
    if policy_str == "cache_aware":
        return PolicyType.CacheAware
    # Match the dict-lookup behavior of the previous implementation.
    raise KeyError(policy_str)
def launch_router(args: argparse.Namespace) -> Optional[Router]:
    """
    Launch the SGLang router with the configuration from parsed arguments.

    Args:
        args: Either a raw argparse.Namespace or an already-built RouterArgs.

    Returns:
        The started Router instance on success, None on failure.
    """
    try:
        # Normalize the input: accept a ready RouterArgs or convert from CLI args.
        router_args = (
            args if isinstance(args, RouterArgs) else RouterArgs.from_cli_args(args)
        )

        router = Router(
            worker_urls=router_args.worker_urls,
            policy=policy_from_str(router_args.policy),
            host=router_args.host,
            port=router_args.port,
            cache_threshold=router_args.cache_threshold,
            cache_routing_prob=router_args.cache_routing_prob,
            eviction_interval_secs=router_args.eviction_interval,
            max_tree_size=router_args.max_tree_size,
        )
        router.start()
        return router
    except Exception as e:
        print(f"Error starting router: {e}", file=sys.stderr)
        return None
class CustomHelpFormatter(
    argparse.RawDescriptionHelpFormatter, argparse.ArgumentDefaultsHelpFormatter
):
    """Help formatter that keeps the description's own layout intact while
    also appending each argument's default value to its help text."""
def parse_router_args(args: List[str]) -> RouterArgs:
    """Parse command line arguments and return RouterArgs instance."""
    parser = argparse.ArgumentParser(
        description="""SGLang Router - High-performance request distribution across worker nodes
Usage:
This launcher enables starting a router with individual worker instances. It is useful for
multi-node setups or when you want to start workers and router separately.
Examples:
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 --cache-threshold 0.7 --cache-routing-prob 0.5
""",
        formatter_class=CustomHelpFormatter,
    )
    # Standalone launcher: flags are registered and read without the
    # "router-" prefix used when co-launching with a server.
    RouterArgs.add_cli_args(parser, use_router_prefix=False)
    return RouterArgs.from_cli_args(parser.parse_args(args), use_router_prefix=False)
def main() -> None:
    """Entry point: parse CLI options, launch the router, exit 1 on failure."""
    parsed = parse_router_args(sys.argv[1:])
    if launch_router(parsed) is None:
        sys.exit(1)


if __name__ == "__main__":
    main()
import argparse
import copy
import multiprocessing as mp
import os
import signal
import sys
import time
from typing import List
import requests
from sglang_router.launch_router import RouterArgs, launch_router
from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs, prepare_server_args
from sglang.srt.utils import is_port_available
from sglang.utils import get_exception_traceback
# Create new process group
def run_server(server_args, dp_rank):
    """Child-process entry point: run one sglang server.

    A fresh process group is created so the parent can tear down the server
    and all of its descendants with a single os.killpg().
    """
    os.setpgrp()  # Create new process group

    # Set DP_RANK environment variable so the server can read its
    # data-parallel rank.
    os.environ["DP_RANK"] = str(dp_rank)

    launch_server(server_args)
def launch_server_process(
    server_args: ServerArgs, worker_port: int, dp_id: int
) -> mp.Process:
    """Spawn one server process configured for the given port and DP slot."""
    # Work on a private copy so the caller's ServerArgs stays untouched.
    worker_args = copy.deepcopy(server_args)
    worker_args.port = worker_port
    worker_args.base_gpu_id = dp_id * worker_args.tp_size
    worker_args.dp_size = 1

    worker_proc = mp.Process(target=run_server, args=(worker_args, dp_id))
    worker_proc.start()
    return worker_proc
def cleanup_processes(processes: List[mp.Process]):
    """Terminate every live process (and its process group), escalating to SIGKILL."""
    print("\nCleaning up processes...")
    for proc in processes:
        if not proc.is_alive():
            continue
        try:
            # Ask the whole process group to terminate gracefully first.
            group_id = os.getpgid(proc.pid)
            os.killpg(group_id, signal.SIGTERM)
            # Give processes some time to terminate gracefully
            proc.join(timeout=3)
            # Escalate if the group ignored SIGTERM.
            if proc.is_alive():
                os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
        except ProcessLookupError:
            pass  # Process already terminated
def setup_signal_handlers(cleanup_func):
    """Install one handler for SIGTERM/SIGINT (and SIGQUIT where available)
    that runs `cleanup_func` and then exits with status 1."""

    def _handler(signum, frame):
        cleanup_func()
        sys.exit(1)

    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, _handler)
    # SIGQUIT does not exist on all platforms (e.g. Windows).
    if hasattr(signal, "SIGQUIT"):
        signal.signal(signal.SIGQUIT, _handler)
def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
    """Poll the server's /health endpoint until it returns 200 or `timeout` elapses.

    Returns:
        True once the endpoint answers 200, False if the deadline passes first.
    """
    deadline = time.time() + timeout
    health_url = f"http://{host}:{port}/health"

    while time.time() < deadline:
        try:
            reply = requests.get(health_url, timeout=5)
        except requests.exceptions.RequestException:
            # Server not reachable yet; keep polling.
            pass
        else:
            if reply.status_code == 200:
                return True
        time.sleep(1)
    return False
def find_available_ports(base_port: int, count: int) -> List[int]:
    """Collect `count` free ports, scanning upward from `base_port`.

    Note: the returned ports are not necessarily consecutive — occupied
    ports are skipped.
    """
    free_ports: List[int] = []
    candidate = base_port
    while len(free_ports) < count:
        if is_port_available(candidate):
            free_ports.append(candidate)
        candidate += 1
    return free_ports
def main():
    """Launch dp_size server worker processes plus one router in front of them.

    Flow: parse combined server+router CLI -> reserve worker ports -> spawn
    workers -> wait until all are healthy -> start the router pointed at them.
    Any failure tears down all spawned processes.
    """
    # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
    mp.set_start_method("spawn")

    parser = argparse.ArgumentParser(
        description="Launch SGLang router and server processes"
    )

    # Server flags are unprefixed; router flags get a "router-" prefix, and the
    # router's host/port are omitted because the server's are reused.
    ServerArgs.add_cli_args(parser)
    RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
    parser.add_argument(
        "--router-dp-worker-base-port",
        type=int,
        default=31000,
        help="Base port number for data parallel workers",
    )

    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)

    # Find available ports for workers
    worker_ports = find_available_ports(
        args.router_dp_worker_base_port, server_args.dp_size
    )

    # Start server processes
    server_processes = []

    try:
        # Launch server processes
        for i, worker_port in enumerate(worker_ports):
            proc = launch_server_process(server_args, worker_port, i)
            server_processes.append(proc)

        # Setup cleanup handler
        setup_signal_handlers(lambda: cleanup_processes(server_processes))

        # Wait for all servers to be healthy
        all_healthy = True

        for port in worker_ports:
            if not wait_for_server_health(server_args.host, port):
                print(f"Server on port {port} failed to become healthy")
                all_healthy = False
                break

        if not all_healthy:
            print("Not all servers are healthy. Shutting down...")
            cleanup_processes(server_processes)
            sys.exit(1)

        print("All servers are healthy. Starting router...")

        # Update router args with worker URLs
        router_args.worker_urls = [
            f"http://{server_args.host}:{port}" for port in worker_ports
        ]

        # Start the router
        router = launch_router(router_args)

        if router is None:
            print("Failed to start router. Shutting down...")
            cleanup_processes(server_processes)
            sys.exit(1)

    except KeyboardInterrupt:
        print("\nReceived shutdown signal...")
    except Exception as e:
        print(f"Error occurred: {e}")
        print(get_exception_traceback())
    finally:
        # Always tear the workers down, whatever path got us here.
        cleanup_processes(server_processes)


if __name__ == "__main__":
    main()
...@@ -9,16 +9,23 @@ class Router: ...@@ -9,16 +9,23 @@ class Router:
A high-performance router for distributing requests across worker nodes. A high-performance router for distributing requests across worker nodes.
Args: Args:
worker_urls: List of URLs for worker nodes that will handle requests worker_urls: List of URLs for worker nodes that will handle requests. Each URL should include
the protocol, host, and port (e.g., ['http://worker1:8000', 'http://worker2:8000'])
policy: Load balancing policy to use. Options: policy: Load balancing policy to use. Options:
- PolicyType.Random: Randomly select workers - PolicyType.Random: Randomly select workers
- PolicyType.RoundRobin: Distribute requests in round-robin fashion - PolicyType.RoundRobin: Distribute requests in round-robin fashion
- PolicyType.ApproxTree: Tree-based routing using tokenizer similarity - PolicyType.CacheAware: Distribute requests in cache-aware fashion
host: Host address to bind the router server host: Host address to bind the router server. Default: '127.0.0.1'
port: Port number to bind the router server port: Port number to bind the router server. Default: 3001
tokenizer_path: Path to tokenizer model file (required for ApproxTree policy) cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker
cache_threshold: Caching threshold value between 0-1 if the match rate exceeds threshold, otherwise routes to the worker with the smallest
tree. Default: 0.5
cache_routing_prob: Probability of using cache-aware routing (0.0-1.0). Default 1.0 for
full cache-aware routing, suitable for perfectly divided prefix workloads. For uneven
workloads, use a lower value to better distribute requests
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
routing. Default: 60
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
""" """
def __init__( def __init__(
...@@ -27,17 +34,20 @@ class Router: ...@@ -27,17 +34,20 @@ class Router:
policy: PolicyType = PolicyType.RoundRobin, policy: PolicyType = PolicyType.RoundRobin,
host: str = "127.0.0.1", host: str = "127.0.0.1",
port: int = 3001, port: int = 3001,
tokenizer_path: Optional[str] = None,
cache_threshold: float = 0.50, cache_threshold: float = 0.50,
cache_routing_prob: float = 1.0,
eviction_interval_secs: int = 60,
max_tree_size: int = 2**24,
): ):
self._router = _Router( self._router = _Router(
worker_urls=worker_urls, worker_urls=worker_urls,
policy=policy, policy=policy,
host=host, host=host,
port=port, port=port,
tokenizer_path=tokenizer_path,
cache_threshold=cache_threshold, cache_threshold=cache_threshold,
cache_routing_prob=cache_routing_prob,
eviction_interval_secs=eviction_interval_secs,
max_tree_size=max_tree_size,
) )
def start(self) -> None: def start(self) -> None:
......
// Python Binding
use pyo3::prelude::*; use pyo3::prelude::*;
pub mod router; pub mod router;
mod server; pub mod server;
pub mod tree; pub mod tree;
#[pyclass(eq)] #[pyclass(eq)]
...@@ -9,7 +8,7 @@ pub mod tree; ...@@ -9,7 +8,7 @@ pub mod tree;
pub enum PolicyType { pub enum PolicyType {
Random, Random,
RoundRobin, RoundRobin,
ApproxTree, CacheAware,
} }
#[pyclass] #[pyclass]
...@@ -18,8 +17,10 @@ struct Router { ...@@ -18,8 +17,10 @@ struct Router {
port: u16, port: u16,
worker_urls: Vec<String>, worker_urls: Vec<String>,
policy: PolicyType, policy: PolicyType,
tokenizer_path: Option<String>, cache_threshold: f32,
cache_threshold: Option<f32>, cache_routing_prob: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
} }
#[pymethods] #[pymethods]
...@@ -30,33 +31,30 @@ impl Router { ...@@ -30,33 +31,30 @@ impl Router {
policy = PolicyType::RoundRobin, policy = PolicyType::RoundRobin,
host = String::from("127.0.0.1"), host = String::from("127.0.0.1"),
port = 3001, port = 3001,
tokenizer_path = None, cache_threshold = 0.50,
cache_threshold = Some(0.50) cache_routing_prob = 1.0,
eviction_interval_secs = 60,
max_tree_size = 2usize.pow(24)
))] ))]
fn new( fn new(
worker_urls: Vec<String>, worker_urls: Vec<String>,
policy: PolicyType, policy: PolicyType,
host: String, host: String,
port: u16, port: u16,
tokenizer_path: Option<String>, cache_threshold: f32,
cache_threshold: Option<f32>, cache_routing_prob: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
) -> PyResult<Self> { ) -> PyResult<Self> {
// Validate required parameters for approx_tree policy
if matches!(policy, PolicyType::ApproxTree) {
if tokenizer_path.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"tokenizer_path is required for approx_tree policy",
));
}
}
Ok(Router { Ok(Router {
host, host,
port, port,
worker_urls, worker_urls,
policy, policy,
tokenizer_path,
cache_threshold, cache_threshold,
cache_routing_prob,
eviction_interval_secs,
max_tree_size,
}) })
} }
...@@ -68,14 +66,11 @@ impl Router { ...@@ -68,14 +66,11 @@ impl Router {
let policy_config = match &self.policy { let policy_config = match &self.policy {
PolicyType::Random => router::PolicyConfig::RandomConfig, PolicyType::Random => router::PolicyConfig::RandomConfig,
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig, PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig,
PolicyType::ApproxTree => router::PolicyConfig::ApproxTreeConfig { PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
tokenizer_path: self cache_threshold: self.cache_threshold,
.tokenizer_path cache_routing_prob: self.cache_routing_prob,
.clone() eviction_interval_secs: self.eviction_interval_secs,
.expect("tokenizer_path is required for approx_tree policy"), max_tree_size: self.max_tree_size,
cache_threshold: self
.cache_threshold
.expect("cache_threshold is required for approx_tree policy"),
}, },
}; };
......
// src/main.rs // src/main.rs
use clap::Parser; use clap::Parser;
use clap::ValueEnum; use clap::ValueEnum;
// declare child modules
mod router;
mod server;
mod tree;
use crate::router::PolicyConfig; use sglang_router_rs::{router::PolicyConfig, server};
#[derive(Debug, Clone, ValueEnum)] #[derive(Debug, Clone, ValueEnum)]
pub enum PolicyType { pub enum PolicyType {
Random, Random,
RoundRobin, RoundRobin,
ApproxTree, CacheAware,
} }
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
...@@ -21,44 +17,70 @@ struct Args { ...@@ -21,44 +17,70 @@ struct Args {
#[arg( #[arg(
long, long,
default_value = "127.0.0.1", default_value = "127.0.0.1",
help = "Host address to bind the server to" help = "Host address to bind the router server to. Default: 127.0.0.1"
)] )]
host: String, host: String,
#[arg(long, default_value_t = 3001, help = "Port number to listen on")] #[arg(
long,
default_value_t = 3001,
help = "Port number to bind the router server to. Default: 3001"
)]
port: u16, port: u16,
#[arg( #[arg(
long, long,
value_delimiter = ',', value_delimiter = ',',
help = "Comma-separated list of worker URLs to distribute requests to" help = "Comma-separated list of worker URLs that will handle the requests. Each URL should include the protocol, host, and port (e.g., http://worker1:8000,http://worker2:8000)"
)] )]
worker_urls: Vec<String>, worker_urls: Vec<String>,
#[arg( #[arg(
long, long,
default_value_t = PolicyType::RoundRobin, default_value_t = PolicyType::CacheAware,
value_enum, value_enum,
help = "Load balancing policy to use: random, round_robin, or approx_tree" help = "Load balancing policy to use for request distribution:\n\
- random: Randomly select workers\n\
- round_robin: Distribute requests in round-robin fashion\n\
- cache_aware: Distribute requests in cache-aware fashion\n"
)] )]
policy: PolicyType, policy: PolicyType,
#[arg( #[arg(
long, long,
default_value_t = 0.5,
requires = "policy",
required_if_eq("policy", "cache_aware"),
help = "Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker if the match rate exceeds threshold, otherwise routes to the worker with the smallest tree. Default: 0.5"
)]
cache_threshold: f32,
#[arg(
long,
default_value_t = 1.0,
requires = "policy",
required_if_eq("policy", "cache_aware"),
help = "Probability of using cache-aware routing (0.0-1.0). Default 1.0 for full cache-aware routing, suitable for perfectly divided prefix workloads. For uneven workloads, use a lower value to better distribute requests"
)]
cache_routing_prob: f32,
#[arg(
long,
default_value_t = 60,
requires = "policy", requires = "policy",
required_if_eq("policy", "approx_tree"), required_if_eq("policy", "cache_aware"),
help = "Path to the tokenizer file, required when using approx_tree policy" help = "Interval in seconds between cache eviction operations in cache-aware routing. Default: 60"
)] )]
tokenizer_path: Option<String>, eviction_interval_secs: u64,
#[arg( #[arg(
long, long,
default_value = "0.50", default_value_t = 2usize.pow(24),
requires = "policy", requires = "policy",
required_if_eq("policy", "approx_tree"), required_if_eq("policy", "cache_aware"),
help = "Cache threshold (0.0-1.0) for approx_tree routing. Routes to cached worker if match rate exceeds threshold, otherwise routes to shortest queue worker" help = "Maximum size of the approximation tree for cache-aware routing. Default: 2^24"
)] )]
cache_threshold: Option<f32>, max_tree_size: usize,
} }
impl Args { impl Args {
...@@ -66,14 +88,11 @@ impl Args { ...@@ -66,14 +88,11 @@ impl Args {
match self.policy { match self.policy {
PolicyType::Random => PolicyConfig::RandomConfig, PolicyType::Random => PolicyConfig::RandomConfig,
PolicyType::RoundRobin => PolicyConfig::RoundRobinConfig, PolicyType::RoundRobin => PolicyConfig::RoundRobinConfig,
PolicyType::ApproxTree => PolicyConfig::ApproxTreeConfig { PolicyType::CacheAware => PolicyConfig::CacheAwareConfig {
tokenizer_path: self cache_threshold: self.cache_threshold,
.tokenizer_path cache_routing_prob: self.cache_routing_prob,
.clone() eviction_interval_secs: self.eviction_interval_secs,
.expect("tokenizer_path is required for approx_tree policy"), max_tree_size: self.max_tree_size,
cache_threshold: self
.cache_threshold
.expect("cache_threshold is required for approx_tree policy"),
}, },
} }
} }
......
use crate::tree::RadixTree; use crate::tree::Tree;
use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse}; use actix_web::{HttpRequest, HttpResponse};
use bytes::Bytes; use bytes::Bytes;
use futures_util::TryStreamExt; use futures_util::{Stream, StreamExt, TryStreamExt};
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Debug; use std::fmt::Debug;
use std::hash::Hash;
use std::pin::Pin;
use std::sync::atomic::AtomicUsize; use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use tokenizers::tokenizer::Tokenizer; use std::thread;
use std::time::Duration;
#[derive(Debug)] #[derive(Debug)]
pub enum Router { pub enum Router {
...@@ -18,34 +21,88 @@ pub enum Router { ...@@ -18,34 +21,88 @@ pub enum Router {
Random { Random {
worker_urls: Vec<String>, worker_urls: Vec<String>,
}, },
ApproxTree { CacheAware {
/*
Cache-Aware Load Balancing Router
This router combines two strategies to optimize both cache utilization and request distribution:
1. Cache-Aware Routing (Approximate Tree)
2. Load Balancing (Shortest Queue)
For each incoming request, the router chooses between these strategies:
- With probability P: Uses cache-aware routing
- With probability (1-P): Uses load balancing
where P is configured via `cache_routing_prob`
Strategy Details:
1. Cache-Aware Routing (Approximate Tree)
-------------------------------------------
This strategy maintains an approximate radix tree for each worker based on request history,
eliminating the need for direct cache state queries. The tree stores raw text characters
instead of token IDs to avoid tokenization overhead.
Process:
a. For each request, find the worker with the highest prefix match
b. If match rate > cache_threshold:
Route to the worker with highest match (likely has relevant data cached)
c. If match rate ≤ cache_threshold:
Route to the worker with smallest tree size (most available cache capacity)
d. Background maintenance:
Periodically evict least recently used leaf nodes to prevent memory overflow
2. Load Balancing (Shortest Queue)
-------------------------------------------
This strategy tracks pending request counts per worker and routes new requests
to the least busy worker for optimal load distribution.
Configuration Parameters:
------------------------
1. cache_routing_prob: (float, 0.0 to 1.0)
- 0.0: Exclusively use load balancing
- 1.0: Exclusively use cache-aware routing
- Between 0-1: Probability of using cache-aware routing vs load balancing
2. cache_threshold: (float, 0.0 to 1.0)
Minimum prefix match ratio to use highest-match routing.
Below this threshold, routes to worker with most available cache space.
3. eviction_interval_secs: (integer)
Interval between LRU eviction cycles for the approximate trees.
4. max_tree_size: (integer)
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
during the next eviction cycle.
*/
worker_urls: Vec<String>, worker_urls: Vec<String>,
// TODO: don't lock the whole tree tree: Arc<Mutex<Tree>>,
url_to_tree: Arc<Mutex<HashMap<String, RadixTree>>>, running_queue: Arc<Mutex<HashMap<String, usize>>>,
tokenizer: Tokenizer, processed_queue: Arc<Mutex<HashMap<String, usize>>>,
url_to_count: Arc<Mutex<HashMap<String, usize>>>,
cache_threshold: f32, cache_threshold: f32,
cache_routing_prob: f32,
_eviction_thread: Option<thread::JoinHandle<()>>, // Store thread handle
}, },
} }
#[derive(Debug)]
pub enum PolicyConfig { pub enum PolicyConfig {
RandomConfig, RandomConfig,
RoundRobinConfig, RoundRobinConfig,
ApproxTreeConfig { CacheAwareConfig {
tokenizer_path: String,
cache_threshold: f32, cache_threshold: f32,
cache_routing_prob: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
}, },
} }
fn get_token_ids_from_request(body: &Bytes, tokenizer: &Tokenizer) -> Vec<u32> { fn get_text_from_request(body: &Bytes) -> String {
// 1. convert body to json // 1. convert body to json
let json = serde_json::from_slice::<serde_json::Value>(body).unwrap(); let json = serde_json::from_slice::<serde_json::Value>(body).unwrap();
// 2. get the text field // 2. get the text field
let text = json.get("text").and_then(|t| t.as_str()).unwrap_or(""); let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
// 3. tokenize the text field return text.to_string();
let tokens = tokenizer.encode(text, false).unwrap();
tokens.get_ids().to_vec()
} }
impl Router { impl Router {
...@@ -56,25 +113,56 @@ impl Router { ...@@ -56,25 +113,56 @@ impl Router {
worker_urls, worker_urls,
current_index: std::sync::atomic::AtomicUsize::new(0), current_index: std::sync::atomic::AtomicUsize::new(0),
}, },
PolicyConfig::ApproxTreeConfig { PolicyConfig::CacheAwareConfig {
tokenizer_path,
cache_threshold, cache_threshold,
cache_routing_prob,
eviction_interval_secs,
max_tree_size,
} => { } => {
let mut url_to_tree = HashMap::new(); let mut running_queue = HashMap::new();
let mut url_to_count = HashMap::new(); for url in &worker_urls {
running_queue.insert(url.clone(), 0);
}
let mut processed_queue = HashMap::new();
for url in &worker_urls {
processed_queue.insert(url.clone(), 0);
}
let tree = Arc::new(Mutex::new(Tree::new()));
let running_queue = Arc::new(Mutex::new(running_queue));
let processed_queue = Arc::new(Mutex::new(processed_queue));
// Create background eviction thread
let tree_clone = Arc::clone(&tree);
let processed_queue_clone = Arc::clone(&processed_queue);
let eviction_thread = thread::spawn(move || {
loop {
// Sleep for the specified interval
thread::sleep(Duration::from_secs(eviction_interval_secs));
let locked_tree_clone = tree_clone.lock().unwrap();
// Run eviction
locked_tree_clone.evict_tenant_data(max_tree_size);
// Print the process queue
let locked_processed_queue = processed_queue_clone.lock().unwrap();
println!("Processed Queue: {:?}", locked_processed_queue);
}
});
for url in &worker_urls { for url in &worker_urls {
url_to_tree.insert(url.clone(), RadixTree::new()); tree.lock().unwrap().insert(&"".to_string(), url);
url_to_count.insert(url.clone(), 0);
} }
Router::ApproxTree { Router::CacheAware {
worker_urls, worker_urls,
url_to_tree: Arc::new(Mutex::new(url_to_tree)), tree,
// TODO: rust ::from_pretrained cannot load from local file, so use ::from_file to load local file running_queue,
tokenizer: Tokenizer::from_file(tokenizer_path).unwrap(), processed_queue,
url_to_count: Arc::new(Mutex::new(url_to_count)),
cache_threshold, cache_threshold,
cache_routing_prob,
_eviction_thread: Some(eviction_thread),
} }
} }
} }
...@@ -84,7 +172,7 @@ impl Router { ...@@ -84,7 +172,7 @@ impl Router {
match self { match self {
Router::RoundRobin { worker_urls, .. } Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls } | Router::Random { worker_urls }
| Router::ApproxTree { worker_urls, .. } => { | Router::CacheAware { worker_urls, .. } => {
if worker_urls.is_empty() { if worker_urls.is_empty() {
None None
} else { } else {
...@@ -100,10 +188,7 @@ impl Router { ...@@ -100,10 +188,7 @@ impl Router {
req: HttpRequest, req: HttpRequest,
body: Bytes, body: Bytes,
) -> HttpResponse { ) -> HttpResponse {
let mut input_ids: Vec<u32> = Vec::new(); let text = get_text_from_request(&body);
if let Router::ApproxTree { tokenizer, .. } = self {
input_ids = get_token_ids_from_request(&body, tokenizer);
}
let worker_url = match self { let worker_url = match self {
Router::RoundRobin { Router::RoundRobin {
...@@ -125,78 +210,73 @@ impl Router { ...@@ -125,78 +210,73 @@ impl Router {
worker_urls[rand::random::<usize>() % worker_urls.len()].clone() worker_urls[rand::random::<usize>() % worker_urls.len()].clone()
} }
Router::ApproxTree { Router::CacheAware {
worker_urls, worker_urls,
url_to_tree, tree,
url_to_count, running_queue,
processed_queue,
cache_threshold, cache_threshold,
cache_routing_prob,
.. ..
} => { } => {
// TODO: pipeline the locks. Release one earlier. // even though the tree is thread-safe, we still put a lock to ensure the whole op (tree read + queue read + tree write + queue write) is atomic to handle some edge cases (e.g. multiple requests with long prefix entering at the same time)
let mut max_matched_rate = 0.0; let mut tree = tree.lock().unwrap();
let mut max_matched_idx = 0; let mut running_queue = running_queue.lock().unwrap();
let locked_url_to_tree = url_to_tree.lock().unwrap(); // Generate a random float between 0 and 1 for probability check
let sampled_p: f32 = rand::random();
// 1. Find the highest matched worker let selected_url = if sampled_p < *cache_routing_prob {
for (i, url) in worker_urls.iter().enumerate() { // Cache-aware routing logic
let tree = locked_url_to_tree.get(url).unwrap(); let (matched_text, matched_worker) = tree.prefix_match(&text);
let matched = tree.prefix_match(&input_ids[..]).len(); let matched_rate =
let matched_rate = matched as f32 / input_ids.len() as f32; matched_text.chars().count() as f32 / text.chars().count() as f32;
if matched_rate > max_matched_rate { if matched_rate > *cache_threshold {
max_matched_rate = matched_rate; matched_worker.to_string()
max_matched_idx = i; } else {
} let m_map: HashMap<String, usize> = tree
} .tenant_char_count
.iter()
.map(|entry| (entry.key().clone(), *entry.value()))
.collect();
// 2. If the rate is higher than the threshold, select the worker. If not, select the worker with the shortest queue println!("map: {:?}, mmap: {:?}", tree.get_tenant_char_count(), m_map);
if max_matched_rate > *cache_threshold {
worker_urls[max_matched_idx].clone() tree.get_smallest_tenant()
} else {
// pick the shortest queue from url_to_count
let locked_url_to_count = url_to_count.lock().unwrap();
let mut min_count = std::usize::MAX;
let mut min_count_id = 0;
for (i, url) in worker_urls.iter().enumerate() {
let count = locked_url_to_count.get(url).unwrap();
if *count < min_count {
min_count = *count;
min_count_id = i;
}
} }
} else {
// Shortest queue routing logic
running_queue
.iter()
.min_by_key(|(_url, &count)| count)
.map(|(url, _)| url.clone())
.unwrap_or_else(|| worker_urls[0].clone())
};
worker_urls[min_count_id].clone() // Update running queue
} let count = running_queue.get_mut(&selected_url).unwrap();
*count += 1;
// Update processed queue
let mut locked_processed_queue = processed_queue.lock().unwrap();
let count = locked_processed_queue.get_mut(&selected_url).unwrap();
*count += 1;
// Update tree with the new request
tree.insert(&text, &selected_url);
selected_url
} }
}; };
if let Router::ApproxTree {
url_to_tree,
url_to_count,
..
} = self
{
// Insert input_ids to the tree
let mut locked_url_to_tree = url_to_tree.lock().unwrap();
let selected_tree = locked_url_to_tree.get_mut(&worker_url).unwrap();
selected_tree.insert(&input_ids[..]);
let mut locked_url_to_count = url_to_count.lock().unwrap();
let count = locked_url_to_count.get_mut(&worker_url).unwrap();
*count += 1;
}
// Check if client requested streaming
let is_stream = serde_json::from_slice::<serde_json::Value>(&body) let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false)) .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
.unwrap_or(false); .unwrap_or(false);
let res = match client let res = match client
.post(format!("{}/generate", worker_url)) .post(format!("{}/generate", worker_url.clone()))
.header( .header(
"Content-Type", "Content-Type",
req.headers() req.headers()
...@@ -216,23 +296,53 @@ impl Router { ...@@ -216,23 +296,53 @@ impl Router {
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
if !is_stream { if !is_stream {
// TODO: do the correction on the tree based on the cached input_ids // For non-streaming requests, get response first
if let Router::ApproxTree { url_to_count, .. } = self { let response = match res.bytes().await {
let mut locked_url_to_count = url_to_count.lock().unwrap();
let count = locked_url_to_count.get_mut(&worker_url).unwrap();
*count -= 1;
}
match res.bytes().await {
Ok(body) => HttpResponse::build(status).body(body.to_vec()), Ok(body) => HttpResponse::build(status).body(body.to_vec()),
Err(_) => HttpResponse::InternalServerError().finish(), Err(_) => HttpResponse::InternalServerError().finish(),
};
// Then decrement running queue counter if using CacheAware
if let Router::CacheAware { running_queue, .. } = self {
if let Ok(mut queue) = running_queue.lock() {
if let Some(count) = queue.get_mut(&worker_url) {
*count = count.saturating_sub(1);
}
}
} }
response
} else if let Router::CacheAware { running_queue, .. } = self {
let running_queue = Arc::clone(running_queue);
let worker_url = worker_url.clone();
HttpResponse::build(status)
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
.streaming(
res.bytes_stream()
.map_err(|_| {
actix_web::error::ErrorInternalServerError("Failed to read stream")
})
.inspect(move |bytes| {
let bytes = bytes.as_ref().unwrap();
if bytes
.as_ref()
.windows(12)
.any(|window| window == b"data: [DONE]")
{
let mut locked_queue = running_queue.lock().unwrap();
let count = locked_queue.get_mut(&worker_url).unwrap();
*count = count.saturating_sub(1);
// print
// println!("streaming is done!!")
}
}),
)
} else { } else {
// TODO: do the correction on the tree based on the cached input_ids. The streaming might be tricker to handle
HttpResponse::build(status) HttpResponse::build(status)
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
.streaming(res.bytes_stream().map_err(|_| { .streaming(res.bytes_stream().map_err(|_| {
actix_web::error::ErrorInternalServerError("Failed to read string") actix_web::error::ErrorInternalServerError("Failed to read stream")
})) }))
} }
} }
......
...@@ -76,6 +76,7 @@ pub async fn startup( ...@@ -76,6 +76,7 @@ pub async fn startup(
) -> std::io::Result<()> { ) -> std::io::Result<()> {
println!("Starting server on {}:{}", host, port); println!("Starting server on {}:{}", host, port);
println!("Worker URLs: {:?}", worker_urls); println!("Worker URLs: {:?}", worker_urls);
println!("Policy Config: {:?}", policy_config);
// Create client once with configuration // Create client once with configuration
let client = reqwest::Client::builder() let client = reqwest::Client::builder()
......
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use rand::distributions::{Alphanumeric, DistString};
use rand::thread_rng;
use std::cmp::min;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::collections::HashMap; use std::collections::HashMap;
use std::mem; use std::sync::Arc;
use std::sync::RwLock;
use std::thread;
use std::time::Duration;
use std::time::{SystemTime, UNIX_EPOCH};
type NodeRef = Arc<Node>;
#[derive(Debug)] #[derive(Debug)]
pub struct Node { struct Node {
pub children: HashMap<u32, Node>, // the key is first id of the child because each child must have unique first id children: DashMap<char, NodeRef>,
pub ids: Vec<u32>, text: RwLock<String>,
pub count: u32, tenant_last_access_time: DashMap<String, u128>,
parent: RwLock<Option<NodeRef>>,
} }
#[derive(Debug)] #[derive(Debug)]
pub struct RadixTree { pub struct Tree {
pub root: Node, root: NodeRef,
// TODO: Char Count per tenant
pub tenant_char_count: DashMap<String, usize>,
} }
fn common_prefix_len(a: &[u32], b: &[u32]) -> usize { // For the heap
let mut i = 0;
while i < a.len() && i < b.len() && a[i] == b[i] { struct EvictionEntry {
i += 1; timestamp: u128,
tenant: String,
node: NodeRef,
}
impl Eq for EvictionEntry {}
impl PartialOrd for EvictionEntry {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.timestamp.cmp(&other.timestamp))
} }
i
} }
impl Default for RadixTree { impl Ord for EvictionEntry {
fn default() -> Self { fn cmp(&self, other: &Self) -> std::cmp::Ordering {
Self::new() self.timestamp.cmp(&other.timestamp)
} }
} }
impl RadixTree { impl PartialEq for EvictionEntry {
fn eq(&self, other: &Self) -> bool {
self.timestamp == other.timestamp
}
}
// For char operations
// Note that in rust, `.len()` or slice is operated on the "byte" level. It causes issues for UTF-8 characters because one character might use multiple bytes.
// https://en.wikipedia.org/wiki/UTF-8
fn shared_prefix_count(a: &str, b: &str) -> usize {
let mut i = 0;
let mut a_iter = a.chars();
let mut b_iter = b.chars();
loop {
match (a_iter.next(), b_iter.next()) {
(Some(a_char), Some(b_char)) if a_char == b_char => {
i += 1;
}
_ => break,
}
}
return i;
}
fn slice_by_chars(s: &str, start: usize, end: usize) -> String {
s.chars().skip(start).take(end - start).collect()
}
impl Tree {
/*
Thread-safe multi tenant radix tree
1. Storing data for multiple tenants (the overlap of multiple radix tree)
2. Node-level lock to enable concurrent acesss on nodes
3. Leaf LRU eviction based on tenant access time
*/
pub fn new() -> Self { pub fn new() -> Self {
RadixTree { Tree {
root: Node { root: Arc::new(Node {
children: HashMap::new(), children: DashMap::new(),
ids: Vec::new(), text: RwLock::new("".to_string()),
count: 0, tenant_last_access_time: DashMap::new(),
}, parent: RwLock::new(None),
}),
tenant_char_count: DashMap::new(),
} }
} }
pub fn insert(&mut self, input_ids: &[u32]) { pub fn insert(&self, text: &str, tenant: &str) {
let mut curr = &mut self.root; // Insert text into tree with given tenant
curr.count += 1;
let mut curr = Arc::clone(&self.root);
let mut curr_idx = 0; let mut curr_idx = 0;
let input_ids_len = input_ids.len();
let timestamp_ms = SystemTime::now()
while curr_idx < input_ids_len { .duration_since(UNIX_EPOCH)
let first_id = &input_ids[curr_idx]; .unwrap()
// TODO: changing this get_mut causes error .as_millis();
if curr.children.contains_key(first_id) {
let child = curr.children.get_mut(first_id).unwrap(); curr.tenant_last_access_time
.insert(tenant.to_string(), timestamp_ms);
let prefix_len = common_prefix_len(&input_ids[curr_idx..], &child.ids);
self.tenant_char_count
if prefix_len == child.ids.len() { .entry(tenant.to_string())
// move curr to child .or_insert(0);
curr = child;
curr.count += 1; let mut prev = Arc::clone(&self.root);
curr_idx += prefix_len;
} else { let text_count = text.chars().count();
// split child
// [child]->... => [child]->[new child]->... while curr_idx < text_count {
let new_child = Node { let first_char = text.chars().nth(curr_idx).unwrap();
// to avoid clone: replace child.children with default value (empty vector) and return the original value
children: mem::take(&mut child.children), curr = prev;
ids: child.ids[prefix_len..].to_vec(),
count: child.count, // dashmap.entry locks the entry until the op is done
}; // if using contains_key + insert, there will be an issue that
// 1. "apple" and "app" entered at the same time
child.ids = child.ids[..prefix_len].to_vec(); // 2. and get inserted to the dashmap concurrently, so only one is inserted
child.children = HashMap::new();
child.children.insert(new_child.ids[0], new_child); match curr.children.entry(first_char) {
Entry::Vacant(entry) => {
curr = child; /*
curr.count += 1; no matched
curr_idx += prefix_len; [curr]
becomes
[curr] => [new node]
*/
let curr_text = slice_by_chars(text, curr_idx, text_count);
let curr_text_count = curr_text.chars().count();
let new_node = Arc::new(Node {
children: DashMap::new(),
text: RwLock::new(curr_text),
tenant_last_access_time: DashMap::new(),
parent: RwLock::new(Some(Arc::clone(&curr))),
});
// Increment char count when creating new node with tenant
self.tenant_char_count
.entry(tenant.to_string())
.and_modify(|count| *count += curr_text_count)
.or_insert(curr_text_count);
new_node
.tenant_last_access_time
.insert(tenant.to_string(), timestamp_ms);
entry.insert(Arc::clone(&new_node));
prev = Arc::clone(&new_node);
curr_idx = text_count;
}
Entry::Occupied(mut entry) => {
// matched
let matched_node = entry.get().clone();
let matched_node_text = matched_node.text.read().unwrap().to_owned();
let matched_node_text_count = matched_node_text.chars().count();
let curr_text = slice_by_chars(text, curr_idx, text_count);
let shared_count = shared_prefix_count(&matched_node_text, &curr_text);
if shared_count < matched_node_text_count {
/*
split the matched node
[curr] -> [matched_node] =>
becomes
[curr] -> [new_node] -> [contracted_matched_node]
*/
let matched_text = slice_by_chars(&matched_node_text, 0, shared_count);
let contracted_text = slice_by_chars(
&matched_node_text,
shared_count,
matched_node_text_count,
);
let matched_text_count = matched_text.chars().count();
let new_node = Arc::new(Node {
text: RwLock::new(matched_text),
children: DashMap::new(),
parent: RwLock::new(Some(Arc::clone(&curr))),
tenant_last_access_time: matched_node.tenant_last_access_time.clone(),
});
let first_new_char = contracted_text.chars().nth(0).unwrap();
new_node
.children
.insert(first_new_char, Arc::clone(&matched_node));
entry.insert(Arc::clone(&new_node));
*matched_node.text.write().unwrap() = contracted_text;
*matched_node.parent.write().unwrap() = Some(Arc::clone(&new_node));
prev = Arc::clone(&new_node);
// Increment char count for the tenant in the new split node
if !prev.tenant_last_access_time.contains_key(tenant) {
self.tenant_char_count
.entry(tenant.to_string())
.and_modify(|count| *count += matched_text_count)
.or_insert(matched_text_count);
}
prev.tenant_last_access_time
.insert(tenant.to_string(), timestamp_ms);
curr_idx += shared_count;
} else {
// move to next node
prev = Arc::clone(&matched_node);
// Increment char count when adding tenant to existing node
if !prev.tenant_last_access_time.contains_key(tenant) {
self.tenant_char_count
.entry(tenant.to_string())
.and_modify(|count| *count += matched_node_text_count)
.or_insert(matched_node_text_count);
}
prev.tenant_last_access_time
.insert(tenant.to_string(), timestamp_ms);
curr_idx += shared_count;
}
} }
} else {
// create new child
let new_child = Node {
children: HashMap::new(),
ids: input_ids[curr_idx..].to_vec(),
count: 0,
};
let first_id = new_child.ids[0];
curr.children.insert(first_id, new_child);
curr = curr.children.get_mut(&first_id).unwrap();
curr.count += 1;
curr_idx = input_ids_len;
} }
} }
} }
pub fn prefix_match<'a>(&self, input_ids: &'a [u32]) -> &'a [u32] { pub fn prefix_match(&self, text: &str) -> (String, String) {
let mut curr = &self.root; let mut curr = Arc::clone(&self.root);
let mut curr_idx = 0; let mut curr_idx = 0;
let input_ids_len = input_ids.len();
while curr_idx < input_ids_len { let mut prev = Arc::clone(&self.root);
match curr.children.get(&input_ids[curr_idx]) { let text_count = text.chars().count();
Some(child) => {
let prefix_len = common_prefix_len(&input_ids[curr_idx..], &child.ids); while curr_idx < text_count {
let first_char = text.chars().nth(curr_idx).unwrap();
let curr_text = slice_by_chars(text, curr_idx, text_count);
curr = prev.clone();
if prefix_len == child.ids.len() { match curr.children.entry(first_char) {
curr_idx += prefix_len; Entry::Occupied(entry) => {
curr = child; let matched_node = entry.get().clone();
let shared_count =
shared_prefix_count(&matched_node.text.read().unwrap(), &curr_text);
let matched_node_text_count = matched_node.text.read().unwrap().chars().count();
if shared_count == matched_node_text_count {
// Full match with current node's text, continue to next node
curr_idx += shared_count;
prev = Arc::clone(&matched_node);
} else { } else {
curr_idx += prefix_len; // Partial match, stop here
curr_idx += shared_count;
prev = Arc::clone(&matched_node);
break; break;
} }
} }
None => { Entry::Vacant(_) => {
// No match found, stop here
break; break;
} }
} }
} }
&input_ids[..curr_idx] curr = prev.clone();
// Select the first tenant (key in the map)
let tenant = curr
.tenant_last_access_time
.iter()
.next()
.map(|kv| kv.key().to_owned())
.unwrap_or("empty".to_string());
// Traverse from the curr node to the root and update the timestamp
let timestamp_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis();
if !tenant.eq("empty") {
let mut current_node = Some(curr);
while let Some(node) = current_node {
node.tenant_last_access_time
.insert(tenant.clone(), timestamp_ms);
current_node = node.parent.read().unwrap().clone();
}
}
let ret_text = slice_by_chars(text, 0, curr_idx);
(ret_text, tenant)
} }
pub fn delete(&mut self, input_ids: &[u32]) { fn leaf_of(node: &NodeRef) -> Vec<String> {
let mut curr = &mut self.root; /*
curr.count -= 1; Return the list of tenants if it's a leaf for the tenant
*/
let mut candidates: HashMap<String, bool> = node
.tenant_last_access_time
.iter()
.map(|entry| (entry.key().clone(), true))
.collect();
let mut curr_idx = 0; for child in node.children.iter() {
let input_ids_len = input_ids.len(); for tenant in child.value().tenant_last_access_time.iter() {
candidates.insert(tenant.key().clone(), false);
}
}
while curr_idx < input_ids_len { candidates
let first_id = &input_ids[curr_idx]; .into_iter()
.filter(|(_, is_leaf)| *is_leaf)
.map(|(tenant, _)| tenant)
.collect()
}
if curr.children.contains_key(first_id) { pub fn evict_tenant_data(&self, max_size: usize) {
let child = curr.children.get(first_id).unwrap(); // Calculate used size and collect leaves
let mut stack = vec![Arc::clone(&self.root)];
let mut used_size_per_tenant: HashMap<String, usize> = HashMap::new();
let mut pq = BinaryHeap::new();
let prefix_len = common_prefix_len(&input_ids[curr_idx..], &child.ids); while let Some(curr) = stack.pop() {
for tenant in curr.tenant_last_access_time.iter() {
let size = used_size_per_tenant
.entry(tenant.key().clone())
.or_insert(0);
*size += curr.text.read().unwrap().chars().count();
}
if prefix_len == child.ids.len() { for child in curr.children.iter() {
if child.count == 1 { stack.push(Arc::clone(child.value()));
// If count will become 0, remove the child }
let child = curr.children.get_mut(first_id).unwrap();
child.count -= 1; // Add leaves to priority queue
curr.children.remove(first_id); for tenant in Tree::leaf_of(&curr) {
break; if let Some(timestamp) = curr.tenant_last_access_time.get(&tenant) {
} else { pq.push(Reverse(EvictionEntry {
// Otherwise decrement count and continue timestamp: *timestamp,
let child = curr.children.get_mut(first_id).unwrap(); tenant: tenant.clone(),
node: Arc::clone(&curr),
}));
}
}
}
println!("Before eviction - Used size per tenant:");
for (tenant, size) in &used_size_per_tenant {
println!("Tenant: {}, Size: {}", tenant, size);
}
child.count -= 1; // Process eviction
curr = child; while let Some(Reverse(entry)) = pq.pop() {
curr_idx += prefix_len; let EvictionEntry { tenant, node, .. } = entry;
if let Some(&used_size) = used_size_per_tenant.get(&tenant) {
if used_size <= max_size {
continue;
}
// Update used size
if let Some(size) = used_size_per_tenant.get_mut(&tenant) {
*size -= node.text.read().unwrap().chars().count();
}
// Decrement when removing tenant from node
if node.tenant_last_access_time.contains_key(&tenant) {
self.tenant_char_count
.entry(tenant.clone())
.and_modify(|count| {
if *count > 0 {
*count -= node.text.read().unwrap().chars().count();
}
});
}
// Remove tenant from node
node.tenant_last_access_time.remove(&tenant);
// Remove empty nodes
if node.children.is_empty() && node.tenant_last_access_time.is_empty() {
if let Some(parent) = node.parent.write().unwrap().as_ref() {
let first_char = node.text.read().unwrap().chars().next().unwrap();
parent.children.remove(&first_char);
}
}
// Add parent to queue if it becomes a leaf
if let Some(parent) = node.parent.read().unwrap().as_ref() {
if Tree::leaf_of(parent).contains(&tenant) {
if let Some(timestamp) = parent.tenant_last_access_time.get(&tenant) {
pq.push(Reverse(EvictionEntry {
timestamp: *timestamp,
tenant: tenant.clone(),
node: Arc::clone(parent),
}));
}
} }
} else {
panic!("No match found for {:?}", input_ids);
} }
} else {
panic!("No match found for {:?}", input_ids);
} }
} }
println!("\nAfter eviction - Used size per tenant:");
for (tenant, size) in &used_size_per_tenant {
println!("Tenant: {}, Size: {}", tenant, size);
}
}
pub fn get_tenant_char_count(&self) -> HashMap<String, usize> {
self.tenant_char_count
.iter()
.map(|entry| (entry.key().clone(), *entry.value()))
.collect()
}
pub fn get_smallest_tenant(&self) -> String {
// Return a placeholder if there are no tenants
if self.tenant_char_count.is_empty() {
return "empty".to_string();
}
// Find the tenant with minimum char count
let mut min_tenant = None;
let mut min_count = usize::MAX;
for entry in self.tenant_char_count.iter() {
let tenant = entry.key();
let count = *entry.value();
if count < min_count {
min_count = count;
min_tenant = Some(tenant.clone());
}
}
// Return the found tenant or "empty" if somehow none was found
min_tenant.unwrap_or_else(|| "empty".to_string())
}
pub fn get_used_size_per_tenant(&self) -> HashMap<String, usize> {
// perform a DFS to traverse all nodes and calculate the total size used by each tenant
let mut used_size_per_tenant: HashMap<String, usize> = HashMap::new();
let mut stack = vec![Arc::clone(&self.root)];
while let Some(curr) = stack.pop() {
let text_count = curr.text.read().unwrap().chars().count();
for tenant in curr.tenant_last_access_time.iter() {
let size = used_size_per_tenant
.entry(tenant.key().clone())
.or_insert(0);
*size += text_count;
}
for child in curr.children.iter() {
stack.push(Arc::clone(child.value()));
}
}
used_size_per_tenant
}
fn node_to_string(node: &NodeRef, prefix: &str, is_last: bool) -> String {
let mut result = String::new();
// Add prefix and branch character
result.push_str(prefix);
result.push_str(if is_last { "└── " } else { "├── " });
// Add node text
let node_text = node.text.read().unwrap();
result.push_str(&format!("'{}' [", node_text));
// Add tenant information with timestamps
let mut tenant_info = Vec::new();
for entry in node.tenant_last_access_time.iter() {
let tenant_id = entry.key();
let timestamp_ms = entry.value();
// Convert milliseconds to seconds and remaining milliseconds
let seconds = (timestamp_ms / 1000) as u64;
let millis = (timestamp_ms % 1000) as u32;
// Create SystemTime from Unix timestamp
let system_time = UNIX_EPOCH + Duration::from_secs(seconds);
// Format time as HH:MM:SS.mmm
let datetime = system_time.duration_since(UNIX_EPOCH).unwrap();
let hours = (datetime.as_secs() % 86400) / 3600;
let minutes = (datetime.as_secs() % 3600) / 60;
let seconds = datetime.as_secs() % 60;
tenant_info.push(format!(
"{} | {:02}:{:02}:{:02}.{:03}",
tenant_id, hours, minutes, seconds, millis
));
}
result.push_str(&tenant_info.join(", "));
result.push_str("]\n");
// Process children
let children: Vec<_> = node.children.iter().collect();
let child_count = children.len();
for (i, entry) in children.iter().enumerate() {
let is_last_child = i == child_count - 1;
let new_prefix = format!("{}{}", prefix, if is_last { " " } else { "│ " });
result.push_str(&Tree::node_to_string(
entry.value(),
&new_prefix,
is_last_child,
));
}
result
} }
// for debug
pub fn pretty_print(&self) { pub fn pretty_print(&self) {
println!("RadixTree:"); if self.root.children.is_empty() {
Self::print_node(&self.root, String::from("")); return;
}
let mut result = String::new();
let children: Vec<_> = self.root.children.iter().collect();
let child_count = children.len();
for (i, entry) in children.iter().enumerate() {
let is_last = i == child_count - 1;
result.push_str(&Tree::node_to_string(entry.value(), "", is_last));
}
println!("{result}");
return;
}
}
// Unit tests
#[cfg(test)]
mod tests {
use std::time::Instant;
use rand::Rng;
use super::*;
#[test]
fn test_get_smallest_tenant() {
let tree = Tree::new();
// Test empty tree
assert_eq!(tree.get_smallest_tenant(), "empty");
// Insert data for tenant1 - "ap" + "icot" = 6 chars
tree.insert("ap", "tenant1");
tree.insert("icot", "tenant1");
// Insert data for tenant2 - "cat" = 3 chars
tree.insert("cat", "tenant2");
// Test - tenant2 should be smallest with 3 chars vs 6 chars
assert_eq!(
tree.get_smallest_tenant(),
"tenant2",
"Expected tenant2 to be smallest with 3 characters"
);
// Insert overlapping data for tenant3 and tenant4 to test equal counts
// tenant3: "do" = 2 chars
// tenant4: "hi" = 2 chars
tree.insert("do", "tenant3");
tree.insert("hi", "tenant4");
// Test - should return either tenant3 or tenant4 (both have 2 chars)
let smallest = tree.get_smallest_tenant();
assert!(
smallest == "tenant3" || smallest == "tenant4",
"Expected either tenant3 or tenant4 (both have 2 characters), got {}",
smallest
);
// Add more text to tenant4 to make it larger
tree.insert("hello", "tenant4"); // Now tenant4 has "hi" + "hello" = 6 chars
// Now tenant3 should be smallest (2 chars vs 6 chars for tenant4)
assert_eq!(
tree.get_smallest_tenant(),
"tenant3",
"Expected tenant3 to be smallest with 2 characters"
);
// Test eviction
tree.evict_tenant_data(3); // This should evict tenants with more than 3 chars
let post_eviction_smallest = tree.get_smallest_tenant();
println!("Smallest tenant after eviction: {}", post_eviction_smallest);
}
#[test]
fn test_tenant_char_count() {
let tree = Tree::new();
// Phase 1: Initial insertions
tree.insert("apple", "tenant1");
tree.insert("apricot", "tenant1");
tree.insert("banana", "tenant1");
tree.insert("amplify", "tenant2");
tree.insert("application", "tenant2");
let computed_sizes = tree.get_used_size_per_tenant();
let maintained_counts: HashMap<String, usize> = tree
.tenant_char_count
.iter()
.map(|entry| (entry.key().clone(), *entry.value()))
.collect();
println!("Phase 1 - Maintained vs Computed counts:");
println!(
"Maintained: {:?}\nComputed: {:?}",
maintained_counts, computed_sizes
);
assert_eq!(
maintained_counts, computed_sizes,
"Phase 1: Initial insertions"
);
// Phase 2: Additional insertions
tree.insert("apartment", "tenant1");
tree.insert("appetite", "tenant2");
tree.insert("ball", "tenant1");
tree.insert("box", "tenant2");
let computed_sizes = tree.get_used_size_per_tenant();
let maintained_counts: HashMap<String, usize> = tree
.tenant_char_count
.iter()
.map(|entry| (entry.key().clone(), *entry.value()))
.collect();
println!("Phase 2 - Maintained vs Computed counts:");
println!(
"Maintained: {:?}\nComputed: {:?}",
maintained_counts, computed_sizes
);
assert_eq!(
maintained_counts, computed_sizes,
"Phase 2: Additional insertions"
);
// Phase 3: Overlapping insertions
tree.insert("zebra", "tenant1");
tree.insert("zebra", "tenant2");
tree.insert("zero", "tenant1");
tree.insert("zero", "tenant2");
let computed_sizes = tree.get_used_size_per_tenant();
let maintained_counts: HashMap<String, usize> = tree
.tenant_char_count
.iter()
.map(|entry| (entry.key().clone(), *entry.value()))
.collect();
println!("Phase 3 - Maintained vs Computed counts:");
println!(
"Maintained: {:?}\nComputed: {:?}",
maintained_counts, computed_sizes
);
assert_eq!(
maintained_counts, computed_sizes,
"Phase 3: Overlapping insertions"
);
// Phase 4: Eviction test
tree.evict_tenant_data(10);
let computed_sizes = tree.get_used_size_per_tenant();
let maintained_counts: HashMap<String, usize> = tree
.tenant_char_count
.iter()
.map(|entry| (entry.key().clone(), *entry.value()))
.collect();
println!("Phase 4 - Maintained vs Computed counts:");
println!(
"Maintained: {:?}\nComputed: {:?}",
maintained_counts, computed_sizes
);
assert_eq!(maintained_counts, computed_sizes, "Phase 4: After eviction");
}
fn random_string(len: usize) -> String {
Alphanumeric.sample_string(&mut thread_rng(), len)
}
#[test]
fn test_cold_start() {
let tree = Tree::new();
let (matched_text, tenant) = tree.prefix_match("hello");
assert_eq!(matched_text, "");
assert_eq!(tenant, "empty");
}
#[test]
fn test_exact_match_seq() {
let tree = Tree::new();
tree.insert("hello", "tenant1");
tree.pretty_print();
tree.insert("apple", "tenant2");
tree.pretty_print();
tree.insert("banana", "tenant3");
tree.pretty_print();
let (matched_text, tenant) = tree.prefix_match("hello");
assert_eq!(matched_text, "hello");
assert_eq!(tenant, "tenant1");
let (matched_text, tenant) = tree.prefix_match("apple");
assert_eq!(matched_text, "apple");
assert_eq!(tenant, "tenant2");
let (matched_text, tenant) = tree.prefix_match("banana");
assert_eq!(matched_text, "banana");
assert_eq!(tenant, "tenant3");
}
#[test]
fn test_exact_match_concurrent() {
let tree = Arc::new(Tree::new());
// spawn 3 threads for insert
let tree_clone = Arc::clone(&tree);
let texts = vec!["hello", "apple", "banana"];
let tenants = vec!["tenant1", "tenant2", "tenant3"];
let mut handles = vec![];
for i in 0..3 {
let tree_clone = Arc::clone(&tree_clone);
let text = texts[i];
let tenant = tenants[i];
let handle = thread::spawn(move || {
tree_clone.insert(text, tenant);
});
handles.push(handle);
}
// wait
for handle in handles {
handle.join().unwrap();
}
// spawn 3 threads for match
let mut handles = vec![];
let tree_clone = Arc::clone(&tree);
for i in 0..3 {
let tree_clone = Arc::clone(&tree_clone);
let text = texts[i];
let tenant = tenants[i];
let handle = thread::spawn(move || {
let (matched_text, matched_tenant) = tree_clone.prefix_match(text);
assert_eq!(matched_text, text);
assert_eq!(matched_tenant, tenant);
});
handles.push(handle);
}
// wait
for handle in handles {
handle.join().unwrap();
}
}
#[test]
fn test_partial_match_concurrent() {
let tree = Arc::new(Tree::new());
// spawn 3 threads for insert
let tree_clone = Arc::clone(&tree);
let texts = vec!["apple", "apabc", "acbdeds"];
let mut handles = vec![];
for i in 0..3 {
let tree_clone = Arc::clone(&tree_clone);
let text = texts[i];
let tenant = "tenant0";
let handle = thread::spawn(move || {
tree_clone.insert(text, tenant);
});
handles.push(handle);
}
// wait
for handle in handles {
handle.join().unwrap();
}
// spawn 3 threads for match
let mut handles = vec![];
let tree_clone = Arc::clone(&tree);
for i in 0..3 {
let tree_clone = Arc::clone(&tree_clone);
let text = texts[i];
let tenant = "tenant0";
let handle = thread::spawn(move || {
let (matched_text, matched_tenant) = tree_clone.prefix_match(text);
assert_eq!(matched_text, text);
assert_eq!(matched_tenant, tenant);
});
handles.push(handle);
}
// wait
for handle in handles {
handle.join().unwrap();
}
} }
fn print_node(node: &Node, prefix: String) { #[test]
// Print current node info with "count" word fn test_group_prefix_insert_match_concurrent() {
println!("{}└── {:?} (count: {})", prefix, node.ids, node.count); let prefix = vec![
"Clock strikes midnight, I'm still wide awake",
"Got dreams bigger than these city lights",
"Time waits for no one, gotta make my move",
"Started from the bottom, that's no metaphor",
];
let suffix = vec![
"Got too much to prove, ain't got time to lose",
"History in the making, yeah, you can't erase this",
];
let tree = Arc::new(Tree::new());
// Print children with proper prefixes let mut handles = vec![];
for (i, child) in node.children.values().enumerate() {
let is_last = i == node.children.len() - 1; for i in 0..prefix.len() {
let child_prefix = if is_last { for j in 0..suffix.len() {
format!("{} ", prefix) // Add space for last child let tree_clone = Arc::clone(&tree);
} else { let text = format!("{} {}", prefix[i], suffix[j]);
format!("{}│ ", prefix) // Add vertical line for other children let tenant = format!("tenant{}", i);
};
Self::print_node(child, child_prefix); let handle = thread::spawn(move || {
tree_clone.insert(&text, &tenant);
});
handles.push(handle);
}
}
// wait
for handle in handles {
handle.join().unwrap();
} }
tree.pretty_print();
// check matching using multi threads
let mut handles = vec![];
for i in 0..prefix.len() {
let tree_clone = Arc::clone(&tree);
let text = prefix[i];
let handle = thread::spawn(move || {
let (matched_text, matched_tenant) = tree_clone.prefix_match(text);
let tenant = format!("tenant{}", i);
assert_eq!(matched_text, text);
assert_eq!(matched_tenant, tenant);
});
handles.push(handle);
}
// wait
for handle in handles {
handle.join().unwrap();
}
}
#[test]
fn test_mixed_concurrent_insert_match() {
// ensure it does not deadlock instead of doing correctness check
let prefix = vec![
"Clock strikes midnight, I'm still wide awake",
"Got dreams bigger than these city lights",
"Time waits for no one, gotta make my move",
"Started from the bottom, that's no metaphor",
];
let suffix = vec![
"Got too much to prove, ain't got time to lose",
"History in the making, yeah, you can't erase this",
];
let tree = Arc::new(Tree::new());
let mut handles = vec![];
for i in 0..prefix.len() {
for j in 0..suffix.len() {
let tree_clone = Arc::clone(&tree);
let text = format!("{} {}", prefix[i], suffix[j]);
let tenant = format!("tenant{}", i);
let handle = thread::spawn(move || {
tree_clone.insert(&text, &tenant);
});
handles.push(handle);
}
}
// check matching using multi threads
for i in 0..prefix.len() {
let tree_clone = Arc::clone(&tree);
let text = prefix[i];
let handle = thread::spawn(move || {
let (matched_text, matched_tenant) = tree_clone.prefix_match(text);
});
handles.push(handle);
}
// wait
for handle in handles {
handle.join().unwrap();
}
}
#[test]
fn test_utf8_split_seq() {
// The string should be indexed and splitted by a utf-8 value basis instead of byte basis
// use .chars() to get the iterator of the utf-8 value
let tree = Arc::new(Tree::new());
let test_pairs = vec![
("你好嗎", "tenant1"),
("你好喔", "tenant2"),
("你心情好嗎", "tenant3"),
];
// Insert sequentially
for i in 0..test_pairs.len() {
let text = test_pairs[i].0;
let tenant = test_pairs[i].1;
tree.insert(text, tenant);
}
tree.pretty_print();
// Test sequentially
for i in 0..test_pairs.len() {
let (matched_text, matched_tenant) = tree.prefix_match(test_pairs[i].0);
assert_eq!(matched_text, test_pairs[i].0);
assert_eq!(matched_tenant, test_pairs[i].1);
}
}
#[test]
fn test_utf8_split_concurrent() {
let tree = Arc::new(Tree::new());
let test_pairs = vec![
("你好嗎", "tenant1"),
("你好喔", "tenant2"),
("你心情好嗎", "tenant3"),
];
// Create multiple threads for insertion
let mut handles = vec![];
for i in 0..test_pairs.len() {
let tree_clone = Arc::clone(&tree);
let text = test_pairs[i].0.to_string();
let tenant = test_pairs[i].1.to_string();
let handle = thread::spawn(move || {
tree_clone.insert(&text, &tenant);
});
handles.push(handle);
}
// Wait for all insertions to complete
for handle in handles {
handle.join().unwrap();
}
tree.pretty_print();
// Create multiple threads for matching
let mut handles = vec![];
for i in 0..test_pairs.len() {
let tree_clone = Arc::clone(&tree);
let text = test_pairs[i].0.to_string();
let tenant = test_pairs[i].1.to_string();
let handle = thread::spawn(move || {
let (matched_text, matched_tenant) = tree_clone.prefix_match(&text);
assert_eq!(matched_text, text);
assert_eq!(matched_tenant, tenant);
});
handles.push(handle);
}
// Wait for all matches to complete
for handle in handles {
handle.join().unwrap();
}
}
#[test]
fn test_simple_eviction() {
let tree = Tree::new();
let max_size = 5;
// Insert strings for both tenants
tree.insert("hello", "tenant1"); // size 5
tree.insert("hello", "tenant2"); // size 5
thread::sleep(Duration::from_millis(10));
tree.insert("world", "tenant2"); // size 5, total for tenant2 = 10
tree.pretty_print();
// Verify initial sizes
let sizes_before = tree.get_used_size_per_tenant();
assert_eq!(sizes_before.get("tenant1").unwrap(), &5); // "hello" = 5
assert_eq!(sizes_before.get("tenant2").unwrap(), &10); // "hello" + "world" = 10
// Evict - should remove "hello" from tenant2 as it's the oldest
tree.evict_tenant_data(max_size);
tree.pretty_print();
// Verify sizes after eviction
let sizes_after = tree.get_used_size_per_tenant();
assert_eq!(sizes_after.get("tenant1").unwrap(), &5); // Should be unchanged
assert_eq!(sizes_after.get("tenant2").unwrap(), &5); // Only "world" remains
// Verify "world" remains for tenant2
let (matched, tenant) = tree.prefix_match("world");
assert_eq!(matched, "world");
assert_eq!(tenant, "tenant2");
}
#[test]
fn test_advanced_eviction() {
let tree = Tree::new();
// Set limits for each tenant
let max_size: usize = 100;
// Define prefixes
let prefixes = vec!["aqwefcisdf", "iajsdfkmade", "kjnzxcvewqe", "iejksduqasd"];
// Insert strings with shared prefixes
for i in 0..100 {
for (j, prefix) in prefixes.iter().enumerate() {
let random_suffix = random_string(10);
let text = format!("{}{}", prefix, random_suffix);
let tenant = format!("tenant{}", j + 1);
tree.insert(&text, &tenant);
}
}
// Perform eviction
tree.evict_tenant_data(max_size);
// Check sizes after eviction
let sizes_after = tree.get_used_size_per_tenant();
// Verify all tenants are under their size limits
for (tenant, &size) in sizes_after.iter() {
assert!(
size <= max_size,
"Tenant {} exceeds size limit. Current size: {}, Limit: {}",
tenant,
size,
max_size
);
}
}
#[test]
fn test_concurrent_operations_with_eviction() {
// Ensure eviction works fine with concurrent insert and match operations for a given period
let tree = Arc::new(Tree::new());
let mut handles = vec![];
let test_duration = Duration::from_secs(10);
let start_time = Instant::now();
let max_size = 100; // Single max size for all tenants
// Spawn eviction thread
{
let tree = Arc::clone(&tree);
let handle = thread::spawn(move || {
while start_time.elapsed() < test_duration {
// Run eviction
tree.evict_tenant_data(max_size);
// Sleep for 5 seconds
thread::sleep(Duration::from_secs(5));
}
});
handles.push(handle);
}
// Spawn 4 worker threads
for thread_id in 0..4 {
let tree = Arc::clone(&tree);
let handle = thread::spawn(move || {
let mut rng = rand::thread_rng();
let tenant = format!("tenant{}", thread_id + 1);
let prefix = format!("prefix{}", thread_id);
while start_time.elapsed() < test_duration {
// Random decision: match or insert (70% match, 30% insert)
if rng.gen_bool(0.7) {
// Perform match operation
let random_len = rng.gen_range(3..10);
let search_str = format!("{}{}", prefix, random_string(random_len));
let (matched, _) = tree.prefix_match(&search_str);
} else {
// Perform insert operation
let random_len = rng.gen_range(5..15);
let insert_str = format!("{}{}", prefix, random_string(random_len));
tree.insert(&insert_str, &tenant);
// println!("Thread {} inserted: {}", thread_id, insert_str);
}
// Small random sleep to vary timing
thread::sleep(Duration::from_millis(rng.gen_range(10..100)));
}
});
handles.push(handle);
}
// Wait for all threads to complete
for handle in handles {
handle.join().unwrap();
}
// final eviction
tree.evict_tenant_data(max_size);
// Final size check
let final_sizes = tree.get_used_size_per_tenant();
println!("Final sizes after test completion: {:?}", final_sizes);
// Verify all tenants are under limit
for (_, &size) in final_sizes.iter() {
assert!(
size <= max_size,
"Tenant exceeds size limit. Final size: {}, Limit: {}",
size,
max_size
);
}
}
#[test]
fn test_leaf_of() {
let tree = Tree::new();
// Single node
tree.insert("hello", "tenant1");
let leaves = Tree::leaf_of(&tree.root.children.get(&'h').unwrap());
assert_eq!(leaves, vec!["tenant1"]);
// Node with multiple tenants
tree.insert("hello", "tenant2");
let leaves = Tree::leaf_of(&tree.root.children.get(&'h').unwrap());
assert_eq!(leaves.len(), 2);
assert!(leaves.contains(&"tenant1".to_string()));
assert!(leaves.contains(&"tenant2".to_string()));
// Non-leaf node
tree.insert("hi", "tenant1");
let leaves = Tree::leaf_of(&tree.root.children.get(&'h').unwrap());
assert!(leaves.is_empty());
}
#[test]
fn test_get_used_size_per_tenant() {
let tree = Tree::new();
// Single tenant
tree.insert("hello", "tenant1");
tree.insert("world", "tenant1");
let sizes = tree.get_used_size_per_tenant();
tree.pretty_print();
println!("{:?}", sizes);
assert_eq!(sizes.get("tenant1").unwrap(), &10); // "hello" + "world"
// Multiple tenants sharing nodes
tree.insert("hello", "tenant2");
tree.insert("help", "tenant2");
let sizes = tree.get_used_size_per_tenant();
tree.pretty_print();
println!("{:?}", sizes);
assert_eq!(sizes.get("tenant1").unwrap(), &10);
assert_eq!(sizes.get("tenant2").unwrap(), &6); // "hello" + "p"
// UTF-8 characters
tree.insert("你好", "tenant3");
let sizes = tree.get_used_size_per_tenant();
tree.pretty_print();
println!("{:?}", sizes);
assert_eq!(sizes.get("tenant3").unwrap(), &2); // 2 Chinese characters
tree.pretty_print();
} }
} }
// NOTE(review): everything below is the *removed* legacy test suite from the
// diff's deleted side. It targets the old token-id based `RadixTree` API
// (operates on `&[u32]` slices, no tenants), superseded by the string/tenant
// `Tree` tests above.
use sglang_router_rs::tree::RadixTree;
// A new tree starts with an empty root: no ids, no children, zero count.
#[test]
fn test_new_tree() {
    let tree = RadixTree::new();
    assert_eq!(tree.root.count, 0);
    assert!(tree.root.children.is_empty());
    assert!(tree.root.ids.is_empty());
}
// One insert creates a single child keyed by the sequence's first id.
#[test]
fn test_single_insertion() {
    let mut tree = RadixTree::new();
    tree.insert(&[1, 2, 3]);
    assert_eq!(tree.root.count, 1);
    assert_eq!(tree.root.children.len(), 1);
    assert_eq!(tree.root.children[&1].ids, vec![1, 2, 3]);
    assert_eq!(tree.root.children[&1].count, 1);
}
// Disjoint sequences live in separate children; no node splitting needed.
#[test]
fn test_multiple_insertions_no_split() {
    let mut tree = RadixTree::new();
    tree.insert(&[1, 2, 3]);
    tree.insert(&[4, 5, 6]);
    assert_eq!(tree.root.count, 2);
    assert_eq!(tree.root.children.len(), 2);
    assert_eq!(tree.root.children[&1].ids, vec![1, 2, 3]);
    assert_eq!(tree.root.children[&4].ids, vec![4, 5, 6]);
}
// Sequences sharing the prefix [1, 2] force a split: the common part stays
// in the parent and each distinct suffix becomes its own child.
#[test]
fn test_insertion_with_split() {
    let mut tree = RadixTree::new();
    tree.insert(&[1, 2, 3, 4]);
    tree.insert(&[1, 2, 5, 6]);
    assert_eq!(tree.root.count, 2);
    assert_eq!(tree.root.children.len(), 1);
    assert_eq!(tree.root.children[&1].ids, vec![1, 2]);
    assert_eq!(tree.root.children[&1].children.len(), 2);
    assert_eq!(tree.root.children[&1].children[&3].ids, vec![3, 4]);
    assert_eq!(tree.root.children[&1].children[&5].ids, vec![5, 6]);
}
// An exact stored sequence matches in full.
#[test]
fn test_prefix_match_exact() {
    let mut tree = RadixTree::new();
    tree.insert(&[1, 2, 3, 4]);
    assert_eq!(tree.prefix_match(&[1, 2, 3, 4]), &[1, 2, 3, 4]);
}
// A query diverging mid-sequence matches only its common prefix.
#[test]
fn test_prefix_match_partial() {
    let mut tree = RadixTree::new();
    tree.insert(&[1, 2, 3, 4]);
    assert_eq!(tree.prefix_match(&[1, 2, 3, 5]), &[1, 2, 3]);
    assert_eq!(tree.prefix_match(&[1, 2, 5]), &[1, 2]);
    assert_eq!(tree.prefix_match(&[1, 5]), &[1]);
}
// A query with no common prefix matches the empty slice.
#[test]
fn test_prefix_match_no_match() {
    let mut tree = RadixTree::new();
    tree.insert(&[1, 2, 3, 4]);
    let empty_slices: &[u32] = &[];
    assert_eq!(tree.prefix_match(&[5, 6, 7]), empty_slices);
}
// Deleting the only sequence returns the tree to its empty state.
#[test]
fn test_delete_leaf() {
    let mut tree = RadixTree::new();
    tree.insert(&[1, 2, 3]);
    tree.delete(&[1, 2, 3]);
    assert_eq!(tree.root.count, 0);
    assert_eq!(tree.root.children.len(), 0);
}
// Deleting one of two sibling sequences leaves the other intact.
#[test]
fn test_delete_with_siblings() {
    let mut tree = RadixTree::new();
    tree.insert(&[1, 2, 3]);
    tree.insert(&[1, 2, 4]);
    tree.delete(&[1, 2, 3]);
    assert_eq!(tree.root.count, 1);
    assert_eq!(tree.root.children[&1].children[&4].ids, vec![4]);
}
// Mixed insert/match/delete sequence over a small tree.
#[test]
fn test_multiple_operations() {
    let mut tree = RadixTree::new();
    // Insert several paths
    tree.insert(&[1, 2, 3]);
    tree.insert(&[1, 2, 4]);
    tree.insert(&[1, 5, 6]);
    // Verify structure
    assert_eq!(tree.root.count, 3);
    assert_eq!(tree.prefix_match(&[1, 2, 3]), &[1, 2, 3]);
    assert_eq!(tree.prefix_match(&[1, 2, 4]), &[1, 2, 4]);
    assert_eq!(tree.prefix_match(&[1, 5, 6]), &[1, 5, 6]);
    // Delete and verify
    tree.delete(&[1, 2, 3]);
    assert_eq!(tree.root.count, 2);
    assert_eq!(tree.prefix_match(&[1, 2, 3]), &[1, 2]); // Now only matches prefix
}
// Deleting an absent sequence is a programming error and panics.
#[test]
#[should_panic(expected = "No match found")]
fn test_delete_nonexistent() {
    let mut tree = RadixTree::new();
    tree.insert(&[1, 2, 3]);
    tree.delete(&[4, 5, 6]); // Should panic
}
// Empty slices are valid input for insert, match, and delete.
#[test]
fn test_empty_input() {
    let mut tree = RadixTree::new();
    let empty_slice: &[u32] = &[];
    tree.insert(empty_slice);
    assert_eq!(tree.prefix_match(empty_slice), empty_slice);
    tree.delete(empty_slice); // Should not panic
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment