Unverified Commit cbedd1db authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

[router] cache-aware load-balancing router v1 (#2114)

parent ad47749b
import itertools import itertools
import json import json
import os
import random import random
import string import string
import threading import threading
import time import time
from argparse import ArgumentParser from argparse import ArgumentParser
from pathlib import Path
from typing import Union
from tqdm import tqdm
import sglang as sgl import sglang as sgl
from sglang.srt.hf_transformers_utils import get_tokenize from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.test.test_utils import ( from sglang.test.test_utils import (
add_common_sglang_args_and_parse, add_common_sglang_args_and_parse,
select_sglang_backend, select_sglang_backend,
) )
from sglang.utils import dump_state_text from sglang.utils import dump_state_text
random.seed(42)
def gen_prompt(tokenizer, token_num): def gen_prompt(tokenizer, token_num):
all_available_tokens = list(tokenizer.get_vocab().values()) all_available_tokens = list(tokenizer.get_vocab().values())
...@@ -24,12 +27,34 @@ def gen_prompt(tokenizer, token_num): ...@@ -24,12 +27,34 @@ def gen_prompt(tokenizer, token_num):
return ret return ret
def get_cache_path(args):
# Create cache directory under ~/.cache/sglang
cache_dir = Path.home() / ".cache" / "sglang"
# Create a unique cache filename based on the arguments that affect generation
cache_key = f"qa_{args.num_qa}_{args.turns}_{args.system_prompt_len}_{args.len_q}_{args.len_a}_{args.tokenizer.replace('/', '_')}.json"
return cache_dir / cache_key
def gen_arguments(args, tokenizer): def gen_arguments(args, tokenizer):
multi_qas = [ cache_path = get_cache_path(args)
{"system_prompt": gen_prompt(tokenizer, args.system_prompt_len), "qas": []}
for _ in range(args.num_qa) # Try to load from cache first
] if cache_path.exists():
for i in range(args.num_qa): print(f"Loading cached arguments from {cache_path}")
with open(cache_path, "r") as f:
return json.load(f)
print("Generating new arguments...")
# First progress bar for system prompts
multi_qas = []
for _ in tqdm(range(args.num_qa), desc="Generating system prompts"):
multi_qas.append(
{"system_prompt": gen_prompt(tokenizer, args.system_prompt_len), "qas": []}
)
# Nested progress bars for QA pairs
for i in tqdm(range(args.num_qa), desc="Generating QA pairs"):
qas = multi_qas[i]["qas"] qas = multi_qas[i]["qas"]
for j in range(args.turns): for j in range(args.turns):
qas.append( qas.append(
...@@ -38,6 +63,13 @@ def gen_arguments(args, tokenizer): ...@@ -38,6 +63,13 @@ def gen_arguments(args, tokenizer):
"new_tokens": args.len_a, "new_tokens": args.len_a,
} }
) )
# Save to cache
cache_path.parent.mkdir(parents=True, exist_ok=True)
with open(cache_path, "w") as f:
json.dump(multi_qas, f)
print(f"Cached arguments saved to {cache_path}")
return multi_qas return multi_qas
...@@ -45,7 +77,7 @@ def gen_arguments(args, tokenizer): ...@@ -45,7 +77,7 @@ def gen_arguments(args, tokenizer):
def multi_turns(s, system_prompt, qas): def multi_turns(s, system_prompt, qas):
s += system_prompt s += system_prompt
for qa in qas: for i, qa in enumerate(qas):
s += qa["prompt"] s += qa["prompt"]
s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True) s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True)
...@@ -62,7 +94,7 @@ def main(args): ...@@ -62,7 +94,7 @@ def main(args):
multi_qas, multi_qas,
temperature=0, temperature=0,
backend=backend, backend=backend,
num_threads=args.parallel, num_threads="auto",
progress_bar=True, progress_bar=True,
) )
latency = time.time() - tic latency = time.time() - tic
...@@ -75,7 +107,6 @@ def main(args): ...@@ -75,7 +107,6 @@ def main(args):
value = { value = {
"task": "multi_turn_system_prompt_chat", "task": "multi_turn_system_prompt_chat",
"backend": args.backend, "backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3), "latency": round(latency, 3),
"num_requests": args.num_qa, "num_requests": args.num_qa,
"num_turns": args.turns, "num_turns": args.turns,
......
...@@ -727,9 +727,9 @@ def sample_generated_shared_prefix_requests( ...@@ -727,9 +727,9 @@ def sample_generated_shared_prefix_requests(
total_input_tokens = 0 total_input_tokens = 0
total_output_tokens = 0 total_output_tokens = 0
for group_idx in range(num_groups): for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
system_prompt = system_prompts[group_idx] system_prompt = system_prompts[group_idx]
for prompt_idx in range(prompts_per_group): for prompt_idx in tqdm(range(prompts_per_group), desc="Generating questions"):
question = questions[group_idx * prompts_per_group + prompt_idx] question = questions[group_idx * prompts_per_group + prompt_idx]
full_prompt = f"{system_prompt}\n\n{question}" full_prompt = f"{system_prompt}\n\n{question}"
prompt_len = len(tokenizer.encode(full_prompt)) prompt_len = len(tokenizer.encode(full_prompt))
......
...@@ -48,9 +48,13 @@ def run_eval(args): ...@@ -48,9 +48,13 @@ def run_eval(args):
# Select backend # Select backend
set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}")) set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}"))
# Read data if args.data_path is None:
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" # Read data
filename = download_and_cache_file(url) url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
filename = download_and_cache_file(url)
else:
filename = args.data_path
lines = list(read_jsonl(filename)) lines = list(read_jsonl(filename))
# Construct prompts # Construct prompts
......
...@@ -591,6 +591,20 @@ dependencies = [ ...@@ -591,6 +591,20 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "dashmap"
version = "6.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf"
dependencies = [
"cfg-if",
"crossbeam-utils",
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]] [[package]]
name = "deranged" name = "deranged"
version = "0.3.11" version = "0.3.11"
...@@ -904,6 +918,12 @@ dependencies = [ ...@@ -904,6 +918,12 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]] [[package]]
name = "hashbrown" name = "hashbrown"
version = "0.15.1" version = "0.15.1"
...@@ -1226,7 +1246,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -1226,7 +1246,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da"
dependencies = [ dependencies = [
"equivalent", "equivalent",
"hashbrown", "hashbrown 0.15.1",
] ]
[[package]] [[package]]
...@@ -2097,7 +2117,9 @@ dependencies = [ ...@@ -2097,7 +2117,9 @@ dependencies = [
"actix-web", "actix-web",
"bytes", "bytes",
"clap", "clap",
"dashmap",
"futures-util", "futures-util",
"http 1.1.0",
"pyo3", "pyo3",
"rand", "rand",
"reqwest", "reqwest",
......
...@@ -24,6 +24,8 @@ futures-util = "0.3" ...@@ -24,6 +24,8 @@ futures-util = "0.3"
serde_json = "1.0" serde_json = "1.0"
pyo3 = { version = "0.22.5", features = ["extension-module"] } pyo3 = { version = "0.22.5", features = ["extension-module"] }
tokenizers = { version = "0.20.3", features = ["http"] } tokenizers = { version = "0.20.3", features = ["http"] }
dashmap = "6.1.0"
http = "1.1.0"
[profile.release] [profile.release]
lto = "thin" lto = "thin"
......
...@@ -46,6 +46,9 @@ pip install <path-to-wheel> ...@@ -46,6 +46,9 @@ pip install <path-to-wheel>
#### Option B: Development Mode #### Option B: Development Mode
For development purposes, you can install the package in editable mode: For development purposes, you can install the package in editable mode:
Warning: Using editable python binding can suffer from performance degradation!! Please build a fresh wheel for every update if you want to test performance.
```bash ```bash
pip install -e . pip install -e .
``` ```
......
# Minimal usage example: create a router over two local workers with the
# default policy, then start serving. Assumes both workers are already running.
from sglang_router import PolicyType, Router

router = Router(
    worker_urls=[
        "http://localhost:30000",
        "http://localhost:30001",
    ]
)
router.start()
import argparse
import os
import signal
import subprocess
import sys
import time
from typing import Dict, List
import requests
from sglang_router import PolicyType, Router
# Global registry of launched worker processes so the signal handlers can
# kill them all on shutdown (see cleanup_processes).
_processes: List[subprocess.Popen] = []
def cleanup_processes(signum=None, frame=None):
    """Cleanup function to kill all worker processes, then exit.

    Installed as a SIGINT/SIGTERM handler; may also be called directly.

    Args:
        signum: Signal number (unused; present for the signal-handler signature).
        frame: Current stack frame (unused; present for the signal-handler signature).
    """
    print("\nCleaning up processes...")
    for process in _processes:
        try:
            # Kill the entire process group so children spawned via shell=True
            # are terminated together with the worker.
            pgid = os.getpgid(process.pid)
            os.killpg(pgid, signal.SIGKILL)
            process.wait()
        except (ProcessLookupError, OSError):
            # Narrowed from a bare `except:` (which would also swallow
            # SystemExit/KeyboardInterrupt): process already exited.
            pass
    sys.exit(1)
# Register signal handlers so the launched workers are always cleaned up
# when this launcher receives Ctrl-C (SIGINT) or a SIGTERM.
signal.signal(signal.SIGINT, cleanup_processes)
signal.signal(signal.SIGTERM, cleanup_processes)
def parse_args():
    """Build the CLI for the router launcher and parse sys.argv.

    Returns:
        argparse.Namespace with host, port, dp, model_path, local_tokenizer_path.
    """
    arg_parser = argparse.ArgumentParser(description="Launch SGLang Router Server")

    # (flag, argparse options) — iterated below to register each option.
    option_specs = [
        (
            "--host",
            dict(type=str, default="localhost", help="Host address to bind the server"),
        ),
        (
            "--port",
            dict(type=int, default=30000, help="Base port number for workers"),
        ),
        (
            "--dp",
            dict(
                type=int,
                default=2,
                help="Number of worker processes (degree of parallelism)",
            ),
        ),
        (
            "--model-path",
            dict(type=str, required=True, help="Path to the model"),
        ),
        (
            "--local-tokenizer-path",
            dict(type=str, required=True, help="Path to the local tokenizer"),
        ),
    ]
    for flag, options in option_specs:
        arg_parser.add_argument(flag, **options)

    return arg_parser.parse_args()
def launch_workers(args) -> tuple[List[subprocess.Popen], List[str]]:
    """Launch one sglang server subprocess per data-parallel rank.

    Each worker gets its own port (base port + rank) and its own GPU via
    CUDA_VISIBLE_DEVICES. Processes are also registered in the module-level
    `_processes` list so the signal handlers can clean them up.

    Returns:
        (processes, worker_urls) — the Popen handles and the matching URLs.
    """
    processes: List[subprocess.Popen] = []
    worker_urls: List[str] = []

    for rank in range(args.dp):
        worker_port = args.port + rank
        worker_urls.append(f"http://{args.host}:{worker_port}")

        # TODO: replace this with launch_server, and move this file to sglang/ because it depends on sglang
        command = (
            f"export CUDA_VISIBLE_DEVICES={rank}; "
            f"python -m sglang.launch_server --model-path {args.model_path} "
            f"--host {args.host} --port {worker_port}"
        )
        print(command)

        proc = subprocess.Popen(command, shell=True)
        processes.append(proc)
        _processes.append(proc)  # registered globally for signal-handler cleanup

    return processes, worker_urls
def wait_for_healthy_workers(worker_urls: List[str], timeout: int = 300) -> bool:
    """Block until every worker answers GET /health with 200, or timeout expires.

    Args:
        worker_urls: URLs of the workers to poll.
        timeout: Maximum seconds to keep polling.

    Returns:
        True if all workers became healthy within the timeout, else False.
    """
    deadline = time.time() + timeout
    # Track per-worker health so already-healthy workers are not re-polled.
    status = {url: False for url in worker_urls}

    while time.time() < deadline:
        print("checking healthiness...")
        for url in worker_urls:
            if status[url]:
                continue
            try:
                response = requests.get(f"{url}/health")
                if response.status_code == 200:
                    print(f"Worker at {url} is healthy")
                    status[url] = True
            except requests.RequestException:
                pass  # worker not reachable yet; retry next round
        if all(status.values()):
            print("All workers are healthy!")
            return True
        time.sleep(5)

    # Timed out: report which workers never became healthy.
    unhealthy_workers = [url for url, ok in status.items() if not ok]
    print(f"Timeout waiting for workers: {unhealthy_workers}")
    return False
def main():
    """Launch all workers, wait for health, start the router, then idle."""
    args = parse_args()
    processes = None
    try:
        # Start every worker subprocess first.
        processes, worker_urls = launch_workers(args)

        # Do not start routing until every backend answers /health.
        if not wait_for_healthy_workers(worker_urls):
            raise RuntimeError("Failed to start all workers")

        router = Router(
            worker_urls=worker_urls,
            policy=PolicyType.ApproxTree,
            tokenizer_path=args.local_tokenizer_path,
        )
        print("Starting router...")
        router.start()

        # Idle loop: keep the parent process alive while the router serves.
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            print("\nShutting down...")
    except Exception as e:
        print(f"Error: {e}")
    finally:
        # Best-effort teardown of any workers that were launched.
        if processes:
            for process in processes:
                process.kill()


if __name__ == "__main__":
    main()
import argparse
import dataclasses
import sys
from typing import List, Optional
from sglang_router import Router
from sglang_router_rs import PolicyType
@dataclasses.dataclass
class RouterArgs:
    """CLI-facing configuration for the SGLang router."""

    # Worker configuration
    worker_urls: List[str]
    host: str = "127.0.0.1"
    port: int = 30000

    # Routing policy
    policy: str = "cache_aware"
    cache_threshold: float = 0.5
    cache_routing_prob: float = 1.0
    eviction_interval: int = 60
    max_tree_size: int = 2**24

    @staticmethod
    def add_cli_args(
        parser: argparse.ArgumentParser,
        use_router_prefix: bool = False,
        exclude_host_port: bool = False,
    ):
        """
        Add router-specific arguments to an argument parser.

        Args:
            parser: The argument parser to add arguments to
            use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts
            exclude_host_port: If True, don't add host and port arguments (used when inheriting from server)
        """
        flag_prefix = "router-" if use_router_prefix else ""

        # Host/port are omitted when an enclosing server CLI already owns them.
        if not exclude_host_port:
            parser.add_argument(
                "--host",
                type=str,
                default=RouterArgs.host,
                help="Host address to bind the router server",
            )
            parser.add_argument(
                "--port",
                type=int,
                default=RouterArgs.port,
                help="Port number to bind the router server",
            )

        # Note: not required here — launch_server fills worker_urls in programmatically.
        parser.add_argument(
            "--worker-urls",
            type=str,
            nargs="+",
            help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)",
        )

        # Routing policy configuration
        parser.add_argument(
            f"--{flag_prefix}policy",
            type=str,
            default=RouterArgs.policy,
            choices=["random", "round_robin", "cache_aware"],
            help="Load balancing policy to use",
        )
        parser.add_argument(
            f"--{flag_prefix}cache-threshold",
            type=float,
            default=RouterArgs.cache_threshold,
            help="Cache threshold (0.0-1.0) for cache-aware routing",
        )
        parser.add_argument(
            f"--{flag_prefix}cache-routing-prob",
            type=float,
            default=RouterArgs.cache_routing_prob,
            help="Probability of using cache-aware routing (0.0-1.0)",
        )
        parser.add_argument(
            f"--{flag_prefix}eviction-interval",
            type=int,
            default=RouterArgs.eviction_interval,
            help="Interval in seconds between cache eviction operations",
        )
        parser.add_argument(
            f"--{flag_prefix}max-tree-size",
            type=int,
            default=RouterArgs.max_tree_size,
            help="Maximum size of the approximation tree for cache-aware routing",
        )

    @classmethod
    def from_cli_args(
        cls, args: argparse.Namespace, use_router_prefix: bool = False
    ) -> "RouterArgs":
        """
        Create RouterArgs instance from parsed command line arguments.

        Args:
            args: Parsed command line arguments
            use_router_prefix: If True, look for arguments with 'router-' prefix
        """
        attr_prefix = "router_" if use_router_prefix else ""

        def opt(name: str):
            # argparse stores '--router-foo' as attribute 'router_foo'.
            return getattr(args, f"{attr_prefix}{name}")

        return cls(
            worker_urls=args.worker_urls,
            host=args.host,
            port=args.port,
            policy=opt("policy"),
            cache_threshold=opt("cache_threshold"),
            cache_routing_prob=opt("cache_routing_prob"),
            eviction_interval=opt("eviction_interval"),
            max_tree_size=opt("max_tree_size"),
        )
def policy_from_str(policy_str: str) -> PolicyType:
    """Map a CLI policy name onto its PolicyType enum value.

    Raises:
        KeyError: if policy_str is not one of the known policy names.
    """
    return {
        "random": PolicyType.Random,
        "round_robin": PolicyType.RoundRobin,
        "cache_aware": PolicyType.CacheAware,
    }[policy_str]
def launch_router(args: argparse.Namespace) -> Optional[Router]:
    """
    Launch the SGLang router with the configuration from parsed arguments.

    Args:
        args: Namespace object containing router configuration
              Can be either raw argparse.Namespace or converted RouterArgs

    Returns:
        Router instance if successful, None if failed
    """
    try:
        # Normalize: accept either a raw Namespace or an already-built RouterArgs.
        router_args = (
            args if isinstance(args, RouterArgs) else RouterArgs.from_cli_args(args)
        )

        router = Router(
            worker_urls=router_args.worker_urls,
            policy=policy_from_str(router_args.policy),
            host=router_args.host,
            port=router_args.port,
            cache_threshold=router_args.cache_threshold,
            cache_routing_prob=router_args.cache_routing_prob,
            eviction_interval_secs=router_args.eviction_interval,
            max_tree_size=router_args.max_tree_size,
        )
        router.start()
        return router
    except Exception as e:
        # Failure is reported to stderr and signalled by returning None
        # so callers (e.g. launch_server) can shut down cleanly.
        print(f"Error starting router: {e}", file=sys.stderr)
        return None
class CustomHelpFormatter(
    argparse.RawDescriptionHelpFormatter, argparse.ArgumentDefaultsHelpFormatter
):
    """Argparse help formatter combining two stdlib behaviors: preserve the
    raw (pre-formatted) description text and append default values to each
    option's help string."""

    pass
def parse_router_args(args: List[str]) -> RouterArgs:
    """Parse command line arguments and return RouterArgs instance.

    Args:
        args: argv-style list (typically sys.argv[1:]).
    """
    parser = argparse.ArgumentParser(
        description="""SGLang Router - High-performance request distribution across worker nodes
Usage:
This launcher enables starting a router with individual worker instances. It is useful for
multi-node setups or when you want to start workers and router separately.
Examples:
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 --cache-threshold 0.7 --cache-routing-prob 0.5
""",
        formatter_class=CustomHelpFormatter,
    )
    # Standalone launcher: no 'router-' prefix needed since there is no
    # enclosing server CLI to clash with.
    RouterArgs.add_cli_args(parser, use_router_prefix=False)
    return RouterArgs.from_cli_args(parser.parse_args(args), use_router_prefix=False)
def main() -> None:
    """CLI entry point: parse arguments, start the router, exit 1 on failure."""
    parsed = parse_router_args(sys.argv[1:])
    if launch_router(parsed) is None:
        sys.exit(1)


if __name__ == "__main__":
    main()
import argparse
import copy
import multiprocessing as mp
import os
import signal
import sys
import time
from typing import List
import requests
from sglang_router.launch_router import RouterArgs, launch_router
from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs, prepare_server_args
from sglang.srt.utils import is_port_available
from sglang.utils import get_exception_traceback
# Create new process group
def run_server(server_args, dp_rank):
    """Entry point executed inside each worker process.

    Detaches into its own process group (so cleanup can signal the whole
    group), publishes the data-parallel rank via the DP_RANK environment
    variable, then blocks in launch_server.
    """
    # New process group: lets cleanup_processes kill this worker and any
    # children it spawns with a single killpg.
    os.setpgrp()

    os.environ["DP_RANK"] = str(dp_rank)

    launch_server(server_args)
def launch_server_process(
    server_args: ServerArgs, worker_port: int, dp_id: int
) -> mp.Process:
    """Spawn one worker server process for a single data-parallel shard.

    Args:
        server_args: Base server configuration (deep-copied, not mutated).
        worker_port: Port this worker should listen on.
        dp_id: Data-parallel rank; also selects the GPU slice via base_gpu_id.

    Returns:
        The started multiprocessing.Process.
    """
    worker_args = copy.deepcopy(server_args)
    worker_args.port = worker_port
    # Each DP shard owns tp_size consecutive GPUs starting at dp_id * tp_size.
    worker_args.base_gpu_id = dp_id * worker_args.tp_size
    # The worker itself runs a single shard; DP is handled by the router.
    worker_args.dp_size = 1

    process = mp.Process(target=run_server, args=(worker_args, dp_id))
    process.start()
    return process
def cleanup_processes(processes: List[mp.Process]):
    """Terminate every live worker process group, escalating to SIGKILL.

    Sends SIGTERM to each process group first, waits up to 3 seconds for a
    graceful exit, then force-kills anything still alive.
    """
    print("\nCleaning up processes...")
    for proc in processes:
        if not proc.is_alive():
            continue
        try:
            # Signal the whole group so worker children die too.
            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
            proc.join(timeout=3)  # grace period for graceful shutdown
            if proc.is_alive():
                # Still running after the grace period: force kill.
                os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
        except ProcessLookupError:
            pass  # Process already terminated
def setup_signal_handlers(cleanup_func):
    """Install handlers for SIGTERM/SIGINT (and SIGQUIT where available)
    that run cleanup_func and then exit with status 1."""

    def _handler(signum, frame):
        cleanup_func()
        sys.exit(1)

    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, _handler)
    # SIGQUIT does not exist on all platforms (e.g. Windows).
    if hasattr(signal, "SIGQUIT"):
        signal.signal(signal.SIGQUIT, _handler)
def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
    """Poll the worker's /health endpoint until it returns 200 or timeout passes.

    Args:
        host: Worker host name or address.
        port: Worker port.
        timeout: Maximum seconds to keep polling.

    Returns:
        True once a 200 response is seen; False if the deadline expires.
    """
    deadline = time.time() + timeout
    health_url = f"http://{host}:{port}/health"

    while time.time() < deadline:
        try:
            if requests.get(health_url, timeout=5).status_code == 200:
                return True
        except requests.exceptions.RequestException:
            pass  # server not up yet; keep polling
        time.sleep(1)
    return False
def find_available_ports(base_port: int, count: int) -> List[int]:
    """Return `count` available ports, scanning upward from `base_port`.

    The returned ports are in increasing order but are NOT guaranteed to be
    consecutive: any port already in use is skipped. (The original docstring
    claimed "consecutive", which the skip logic contradicts.)

    Args:
        base_port: First port number to probe.
        count: Number of free ports required.

    Returns:
        A list of `count` distinct available port numbers.

    Raises:
        RuntimeError: if the scan passes the end of the valid port range
            (65535) before finding enough free ports — previously this case
            looped forever.
    """
    available_ports: List[int] = []
    current_port = base_port
    while len(available_ports) < count:
        if current_port > 65535:
            raise RuntimeError(
                f"Could not find {count} available ports starting from {base_port}"
            )
        if is_port_available(current_port):
            available_ports.append(current_port)
        current_port += 1
    return available_ports
def main():
    """Launch dp_size worker servers plus one router and supervise them.

    Flow: parse combined server+router CLI → pick free worker ports → spawn
    workers → wait for health → start the router over the healthy workers.
    Workers are always cleaned up in the finally block.
    """
    # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
    mp.set_start_method("spawn")

    parser = argparse.ArgumentParser(
        description="Launch SGLang router and server processes"
    )

    # Combined CLI: server args plus router args under a 'router-' prefix
    # (host/port excluded — the server owns those flags).
    ServerArgs.add_cli_args(parser)
    RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
    parser.add_argument(
        "--router-dp-worker-base-port",
        type=int,
        default=31000,
        help="Base port number for data parallel workers",
    )

    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)

    # Find available ports for workers (one per data-parallel shard)
    worker_ports = find_available_ports(
        args.router_dp_worker_base_port, server_args.dp_size
    )

    # Start server processes
    server_processes = []

    try:
        # Launch server processes, one per worker port / DP rank
        for i, worker_port in enumerate(worker_ports):
            proc = launch_server_process(server_args, worker_port, i)
            server_processes.append(proc)

        # Setup cleanup handler
        # NOTE(review): handlers are installed after workers are spawned — a
        # signal arriving in that window is handled by the finally block only.
        setup_signal_handlers(lambda: cleanup_processes(server_processes))

        # Wait for all servers to be healthy
        all_healthy = True

        for port in worker_ports:
            if not wait_for_server_health(server_args.host, port):
                print(f"Server on port {port} failed to become healthy")
                all_healthy = False
                break

        if not all_healthy:
            print("Not all servers are healthy. Shutting down...")
            cleanup_processes(server_processes)
            sys.exit(1)

        print("All servers are healthy. Starting router...")

        # Update router args with worker URLs
        router_args.worker_urls = [
            f"http://{server_args.host}:{port}" for port in worker_ports
        ]

        # Start the router (returns None on failure; see launch_router)
        router = launch_router(router_args)
        if router is None:
            print("Failed to start router. Shutting down...")
            cleanup_processes(server_processes)
            sys.exit(1)

    except KeyboardInterrupt:
        print("\nReceived shutdown signal...")
    except Exception as e:
        print(f"Error occurred: {e}")
        print(get_exception_traceback())
    finally:
        # Always tear down workers, whether we exit normally or on error.
        cleanup_processes(server_processes)


if __name__ == "__main__":
    main()
...@@ -9,16 +9,23 @@ class Router: ...@@ -9,16 +9,23 @@ class Router:
A high-performance router for distributing requests across worker nodes. A high-performance router for distributing requests across worker nodes.
Args: Args:
worker_urls: List of URLs for worker nodes that will handle requests worker_urls: List of URLs for worker nodes that will handle requests. Each URL should include
the protocol, host, and port (e.g., ['http://worker1:8000', 'http://worker2:8000'])
policy: Load balancing policy to use. Options: policy: Load balancing policy to use. Options:
- PolicyType.Random: Randomly select workers - PolicyType.Random: Randomly select workers
- PolicyType.RoundRobin: Distribute requests in round-robin fashion - PolicyType.RoundRobin: Distribute requests in round-robin fashion
- PolicyType.ApproxTree: Tree-based routing using tokenizer similarity - PolicyType.CacheAware: Distribute requests in cache-aware fashion
host: Host address to bind the router server host: Host address to bind the router server. Default: '127.0.0.1'
port: Port number to bind the router server port: Port number to bind the router server. Default: 3001
tokenizer_path: Path to tokenizer model file (required for ApproxTree policy) cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker
cache_threshold: Caching threshold value between 0-1 if the match rate exceeds threshold, otherwise routes to the worker with the smallest
tree. Default: 0.5
cache_routing_prob: Probability of using cache-aware routing (0.0-1.0). Default 1.0 for
full cache-aware routing, suitable for perfectly divided prefix workloads. For uneven
workloads, use a lower value to better distribute requests
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
routing. Default: 60
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
""" """
def __init__( def __init__(
...@@ -27,17 +34,20 @@ class Router: ...@@ -27,17 +34,20 @@ class Router:
policy: PolicyType = PolicyType.RoundRobin, policy: PolicyType = PolicyType.RoundRobin,
host: str = "127.0.0.1", host: str = "127.0.0.1",
port: int = 3001, port: int = 3001,
tokenizer_path: Optional[str] = None,
cache_threshold: float = 0.50, cache_threshold: float = 0.50,
cache_routing_prob: float = 1.0,
eviction_interval_secs: int = 60,
max_tree_size: int = 2**24,
): ):
self._router = _Router( self._router = _Router(
worker_urls=worker_urls, worker_urls=worker_urls,
policy=policy, policy=policy,
host=host, host=host,
port=port, port=port,
tokenizer_path=tokenizer_path,
cache_threshold=cache_threshold, cache_threshold=cache_threshold,
cache_routing_prob=cache_routing_prob,
eviction_interval_secs=eviction_interval_secs,
max_tree_size=max_tree_size,
) )
def start(self) -> None: def start(self) -> None:
......
// Python Binding
use pyo3::prelude::*; use pyo3::prelude::*;
pub mod router; pub mod router;
mod server; pub mod server;
pub mod tree; pub mod tree;
#[pyclass(eq)] #[pyclass(eq)]
...@@ -9,7 +8,7 @@ pub mod tree; ...@@ -9,7 +8,7 @@ pub mod tree;
pub enum PolicyType { pub enum PolicyType {
Random, Random,
RoundRobin, RoundRobin,
ApproxTree, CacheAware,
} }
#[pyclass] #[pyclass]
...@@ -18,8 +17,10 @@ struct Router { ...@@ -18,8 +17,10 @@ struct Router {
port: u16, port: u16,
worker_urls: Vec<String>, worker_urls: Vec<String>,
policy: PolicyType, policy: PolicyType,
tokenizer_path: Option<String>, cache_threshold: f32,
cache_threshold: Option<f32>, cache_routing_prob: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
} }
#[pymethods] #[pymethods]
...@@ -30,33 +31,30 @@ impl Router { ...@@ -30,33 +31,30 @@ impl Router {
policy = PolicyType::RoundRobin, policy = PolicyType::RoundRobin,
host = String::from("127.0.0.1"), host = String::from("127.0.0.1"),
port = 3001, port = 3001,
tokenizer_path = None, cache_threshold = 0.50,
cache_threshold = Some(0.50) cache_routing_prob = 1.0,
eviction_interval_secs = 60,
max_tree_size = 2usize.pow(24)
))] ))]
fn new( fn new(
worker_urls: Vec<String>, worker_urls: Vec<String>,
policy: PolicyType, policy: PolicyType,
host: String, host: String,
port: u16, port: u16,
tokenizer_path: Option<String>, cache_threshold: f32,
cache_threshold: Option<f32>, cache_routing_prob: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
) -> PyResult<Self> { ) -> PyResult<Self> {
// Validate required parameters for approx_tree policy
if matches!(policy, PolicyType::ApproxTree) {
if tokenizer_path.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"tokenizer_path is required for approx_tree policy",
));
}
}
Ok(Router { Ok(Router {
host, host,
port, port,
worker_urls, worker_urls,
policy, policy,
tokenizer_path,
cache_threshold, cache_threshold,
cache_routing_prob,
eviction_interval_secs,
max_tree_size,
}) })
} }
...@@ -68,14 +66,11 @@ impl Router { ...@@ -68,14 +66,11 @@ impl Router {
let policy_config = match &self.policy { let policy_config = match &self.policy {
PolicyType::Random => router::PolicyConfig::RandomConfig, PolicyType::Random => router::PolicyConfig::RandomConfig,
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig, PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig,
PolicyType::ApproxTree => router::PolicyConfig::ApproxTreeConfig { PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
tokenizer_path: self cache_threshold: self.cache_threshold,
.tokenizer_path cache_routing_prob: self.cache_routing_prob,
.clone() eviction_interval_secs: self.eviction_interval_secs,
.expect("tokenizer_path is required for approx_tree policy"), max_tree_size: self.max_tree_size,
cache_threshold: self
.cache_threshold
.expect("cache_threshold is required for approx_tree policy"),
}, },
}; };
......
// src/main.rs // src/main.rs
use clap::Parser; use clap::Parser;
use clap::ValueEnum; use clap::ValueEnum;
// declare child modules
mod router;
mod server;
mod tree;
use crate::router::PolicyConfig; use sglang_router_rs::{router::PolicyConfig, server};
#[derive(Debug, Clone, ValueEnum)] #[derive(Debug, Clone, ValueEnum)]
pub enum PolicyType { pub enum PolicyType {
Random, Random,
RoundRobin, RoundRobin,
ApproxTree, CacheAware,
} }
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
...@@ -21,44 +17,70 @@ struct Args { ...@@ -21,44 +17,70 @@ struct Args {
#[arg( #[arg(
long, long,
default_value = "127.0.0.1", default_value = "127.0.0.1",
help = "Host address to bind the server to" help = "Host address to bind the router server to. Default: 127.0.0.1"
)] )]
host: String, host: String,
#[arg(long, default_value_t = 3001, help = "Port number to listen on")] #[arg(
long,
default_value_t = 3001,
help = "Port number to bind the router server to. Default: 3001"
)]
port: u16, port: u16,
#[arg( #[arg(
long, long,
value_delimiter = ',', value_delimiter = ',',
help = "Comma-separated list of worker URLs to distribute requests to" help = "Comma-separated list of worker URLs that will handle the requests. Each URL should include the protocol, host, and port (e.g., http://worker1:8000,http://worker2:8000)"
)] )]
worker_urls: Vec<String>, worker_urls: Vec<String>,
#[arg( #[arg(
long, long,
default_value_t = PolicyType::RoundRobin, default_value_t = PolicyType::CacheAware,
value_enum, value_enum,
help = "Load balancing policy to use: random, round_robin, or approx_tree" help = "Load balancing policy to use for request distribution:\n\
- random: Randomly select workers\n\
- round_robin: Distribute requests in round-robin fashion\n\
- cache_aware: Distribute requests in cache-aware fashion\n"
)] )]
policy: PolicyType, policy: PolicyType,
#[arg( #[arg(
long, long,
default_value_t = 0.5,
requires = "policy",
required_if_eq("policy", "cache_aware"),
help = "Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker if the match rate exceeds threshold, otherwise routes to the worker with the smallest tree. Default: 0.5"
)]
cache_threshold: f32,
#[arg(
long,
default_value_t = 1.0,
requires = "policy",
required_if_eq("policy", "cache_aware"),
help = "Probability of using cache-aware routing (0.0-1.0). Default 1.0 for full cache-aware routing, suitable for perfectly divided prefix workloads. For uneven workloads, use a lower value to better distribute requests"
)]
cache_routing_prob: f32,
#[arg(
long,
default_value_t = 60,
requires = "policy", requires = "policy",
required_if_eq("policy", "approx_tree"), required_if_eq("policy", "cache_aware"),
help = "Path to the tokenizer file, required when using approx_tree policy" help = "Interval in seconds between cache eviction operations in cache-aware routing. Default: 60"
)] )]
tokenizer_path: Option<String>, eviction_interval_secs: u64,
#[arg( #[arg(
long, long,
default_value = "0.50", default_value_t = 2usize.pow(24),
requires = "policy", requires = "policy",
required_if_eq("policy", "approx_tree"), required_if_eq("policy", "cache_aware"),
help = "Cache threshold (0.0-1.0) for approx_tree routing. Routes to cached worker if match rate exceeds threshold, otherwise routes to shortest queue worker" help = "Maximum size of the approximation tree for cache-aware routing. Default: 2^24"
)] )]
cache_threshold: Option<f32>, max_tree_size: usize,
} }
impl Args { impl Args {
...@@ -66,14 +88,11 @@ impl Args { ...@@ -66,14 +88,11 @@ impl Args {
match self.policy { match self.policy {
PolicyType::Random => PolicyConfig::RandomConfig, PolicyType::Random => PolicyConfig::RandomConfig,
PolicyType::RoundRobin => PolicyConfig::RoundRobinConfig, PolicyType::RoundRobin => PolicyConfig::RoundRobinConfig,
PolicyType::ApproxTree => PolicyConfig::ApproxTreeConfig { PolicyType::CacheAware => PolicyConfig::CacheAwareConfig {
tokenizer_path: self cache_threshold: self.cache_threshold,
.tokenizer_path cache_routing_prob: self.cache_routing_prob,
.clone() eviction_interval_secs: self.eviction_interval_secs,
.expect("tokenizer_path is required for approx_tree policy"), max_tree_size: self.max_tree_size,
cache_threshold: self
.cache_threshold
.expect("cache_threshold is required for approx_tree policy"),
}, },
} }
} }
......
use crate::tree::RadixTree; use crate::tree::Tree;
use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse}; use actix_web::{HttpRequest, HttpResponse};
use bytes::Bytes; use bytes::Bytes;
use futures_util::TryStreamExt; use futures_util::{Stream, StreamExt, TryStreamExt};
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Debug; use std::fmt::Debug;
use std::hash::Hash;
use std::pin::Pin;
use std::sync::atomic::AtomicUsize; use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use tokenizers::tokenizer::Tokenizer; use std::thread;
use std::time::Duration;
#[derive(Debug)] #[derive(Debug)]
pub enum Router { pub enum Router {
...@@ -18,34 +21,88 @@ pub enum Router { ...@@ -18,34 +21,88 @@ pub enum Router {
Random { Random {
worker_urls: Vec<String>, worker_urls: Vec<String>,
}, },
ApproxTree { CacheAware {
/*
Cache-Aware Load Balancing Router
This router combines two strategies to optimize both cache utilization and request distribution:
1. Cache-Aware Routing (Approximate Tree)
2. Load Balancing (Shortest Queue)
For each incoming request, the router chooses between these strategies:
- With probability P: Uses cache-aware routing
- With probability (1-P): Uses load balancing
where P is configured via `cache_routing_prob`
Strategy Details:
1. Cache-Aware Routing (Approximate Tree)
-------------------------------------------
This strategy maintains an approximate radix tree for each worker based on request history,
eliminating the need for direct cache state queries. The tree stores raw text characters
instead of token IDs to avoid tokenization overhead.
Process:
a. For each request, find the worker with the highest prefix match
b. If match rate > cache_threshold:
Route to the worker with highest match (likely has relevant data cached)
c. If match rate ≤ cache_threshold:
Route to the worker with smallest tree size (most available cache capacity)
d. Background maintenance:
Periodically evict least recently used leaf nodes to prevent memory overflow
2. Load Balancing (Shortest Queue)
-------------------------------------------
This strategy tracks pending request counts per worker and routes new requests
to the least busy worker for optimal load distribution.
Configuration Parameters:
------------------------
1. cache_routing_prob: (float, 0.0 to 1.0)
- 0.0: Exclusively use load balancing
- 1.0: Exclusively use cache-aware routing
- Between 0-1: Probability of using cache-aware routing vs load balancing
2. cache_threshold: (float, 0.0 to 1.0)
Minimum prefix match ratio to use highest-match routing.
Below this threshold, routes to worker with most available cache space.
3. eviction_interval_secs: (integer)
Interval between LRU eviction cycles for the approximate trees.
4. max_tree_size: (integer)
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
during the next eviction cycle.
*/
worker_urls: Vec<String>, worker_urls: Vec<String>,
// TODO: don't lock the whole tree tree: Arc<Mutex<Tree>>,
url_to_tree: Arc<Mutex<HashMap<String, RadixTree>>>, running_queue: Arc<Mutex<HashMap<String, usize>>>,
tokenizer: Tokenizer, processed_queue: Arc<Mutex<HashMap<String, usize>>>,
url_to_count: Arc<Mutex<HashMap<String, usize>>>,
cache_threshold: f32, cache_threshold: f32,
cache_routing_prob: f32,
_eviction_thread: Option<thread::JoinHandle<()>>, // Store thread handle
}, },
} }
#[derive(Debug)]
pub enum PolicyConfig { pub enum PolicyConfig {
RandomConfig, RandomConfig,
RoundRobinConfig, RoundRobinConfig,
ApproxTreeConfig { CacheAwareConfig {
tokenizer_path: String,
cache_threshold: f32, cache_threshold: f32,
cache_routing_prob: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
}, },
} }
fn get_token_ids_from_request(body: &Bytes, tokenizer: &Tokenizer) -> Vec<u32> { fn get_text_from_request(body: &Bytes) -> String {
// 1. convert body to json // 1. convert body to json
let json = serde_json::from_slice::<serde_json::Value>(body).unwrap(); let json = serde_json::from_slice::<serde_json::Value>(body).unwrap();
// 2. get the text field // 2. get the text field
let text = json.get("text").and_then(|t| t.as_str()).unwrap_or(""); let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
// 3. tokenize the text field return text.to_string();
let tokens = tokenizer.encode(text, false).unwrap();
tokens.get_ids().to_vec()
} }
impl Router { impl Router {
...@@ -56,25 +113,56 @@ impl Router { ...@@ -56,25 +113,56 @@ impl Router {
worker_urls, worker_urls,
current_index: std::sync::atomic::AtomicUsize::new(0), current_index: std::sync::atomic::AtomicUsize::new(0),
}, },
PolicyConfig::ApproxTreeConfig { PolicyConfig::CacheAwareConfig {
tokenizer_path,
cache_threshold, cache_threshold,
cache_routing_prob,
eviction_interval_secs,
max_tree_size,
} => { } => {
let mut url_to_tree = HashMap::new(); let mut running_queue = HashMap::new();
let mut url_to_count = HashMap::new(); for url in &worker_urls {
running_queue.insert(url.clone(), 0);
}
let mut processed_queue = HashMap::new();
for url in &worker_urls {
processed_queue.insert(url.clone(), 0);
}
let tree = Arc::new(Mutex::new(Tree::new()));
let running_queue = Arc::new(Mutex::new(running_queue));
let processed_queue = Arc::new(Mutex::new(processed_queue));
// Create background eviction thread
let tree_clone = Arc::clone(&tree);
let processed_queue_clone = Arc::clone(&processed_queue);
let eviction_thread = thread::spawn(move || {
loop {
// Sleep for the specified interval
thread::sleep(Duration::from_secs(eviction_interval_secs));
let locked_tree_clone = tree_clone.lock().unwrap();
// Run eviction
locked_tree_clone.evict_tenant_data(max_tree_size);
// Print the process queue
let locked_processed_queue = processed_queue_clone.lock().unwrap();
println!("Processed Queue: {:?}", locked_processed_queue);
}
});
for url in &worker_urls { for url in &worker_urls {
url_to_tree.insert(url.clone(), RadixTree::new()); tree.lock().unwrap().insert(&"".to_string(), url);
url_to_count.insert(url.clone(), 0);
} }
Router::ApproxTree { Router::CacheAware {
worker_urls, worker_urls,
url_to_tree: Arc::new(Mutex::new(url_to_tree)), tree,
// TODO: rust ::from_pretrained cannot load from local file, so use ::from_file to load local file running_queue,
tokenizer: Tokenizer::from_file(tokenizer_path).unwrap(), processed_queue,
url_to_count: Arc::new(Mutex::new(url_to_count)),
cache_threshold, cache_threshold,
cache_routing_prob,
_eviction_thread: Some(eviction_thread),
} }
} }
} }
...@@ -84,7 +172,7 @@ impl Router { ...@@ -84,7 +172,7 @@ impl Router {
match self { match self {
Router::RoundRobin { worker_urls, .. } Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls } | Router::Random { worker_urls }
| Router::ApproxTree { worker_urls, .. } => { | Router::CacheAware { worker_urls, .. } => {
if worker_urls.is_empty() { if worker_urls.is_empty() {
None None
} else { } else {
...@@ -100,10 +188,7 @@ impl Router { ...@@ -100,10 +188,7 @@ impl Router {
req: HttpRequest, req: HttpRequest,
body: Bytes, body: Bytes,
) -> HttpResponse { ) -> HttpResponse {
let mut input_ids: Vec<u32> = Vec::new(); let text = get_text_from_request(&body);
if let Router::ApproxTree { tokenizer, .. } = self {
input_ids = get_token_ids_from_request(&body, tokenizer);
}
let worker_url = match self { let worker_url = match self {
Router::RoundRobin { Router::RoundRobin {
...@@ -125,78 +210,73 @@ impl Router { ...@@ -125,78 +210,73 @@ impl Router {
worker_urls[rand::random::<usize>() % worker_urls.len()].clone() worker_urls[rand::random::<usize>() % worker_urls.len()].clone()
} }
Router::ApproxTree { Router::CacheAware {
worker_urls, worker_urls,
url_to_tree, tree,
url_to_count, running_queue,
processed_queue,
cache_threshold, cache_threshold,
cache_routing_prob,
.. ..
} => { } => {
// TODO: pipeline the locks. Release one earlier. // even though the tree is thread-safe, we still put a lock to ensure the whole op (tree read + queue read + tree write + queue write) is atomic to handle some edge cases (e.g. multiple requests with long prefix entering at the same time)
let mut max_matched_rate = 0.0; let mut tree = tree.lock().unwrap();
let mut max_matched_idx = 0; let mut running_queue = running_queue.lock().unwrap();
let locked_url_to_tree = url_to_tree.lock().unwrap(); // Generate a random float between 0 and 1 for probability check
let sampled_p: f32 = rand::random();
// 1. Find the highest matched worker let selected_url = if sampled_p < *cache_routing_prob {
for (i, url) in worker_urls.iter().enumerate() { // Cache-aware routing logic
let tree = locked_url_to_tree.get(url).unwrap(); let (matched_text, matched_worker) = tree.prefix_match(&text);
let matched = tree.prefix_match(&input_ids[..]).len(); let matched_rate =
let matched_rate = matched as f32 / input_ids.len() as f32; matched_text.chars().count() as f32 / text.chars().count() as f32;
if matched_rate > max_matched_rate { if matched_rate > *cache_threshold {
max_matched_rate = matched_rate; matched_worker.to_string()
max_matched_idx = i; } else {
} let m_map: HashMap<String, usize> = tree
} .tenant_char_count
.iter()
.map(|entry| (entry.key().clone(), *entry.value()))
.collect();
// 2. If the rate is higher than the threshold, select the worker. If not, select the worker with the shortest queue println!("map: {:?}, mmap: {:?}", tree.get_tenant_char_count(), m_map);
if max_matched_rate > *cache_threshold {
worker_urls[max_matched_idx].clone() tree.get_smallest_tenant()
} else {
// pick the shortest queue from url_to_count
let locked_url_to_count = url_to_count.lock().unwrap();
let mut min_count = std::usize::MAX;
let mut min_count_id = 0;
for (i, url) in worker_urls.iter().enumerate() {
let count = locked_url_to_count.get(url).unwrap();
if *count < min_count {
min_count = *count;
min_count_id = i;
}
} }
} else {
// Shortest queue routing logic
running_queue
.iter()
.min_by_key(|(_url, &count)| count)
.map(|(url, _)| url.clone())
.unwrap_or_else(|| worker_urls[0].clone())
};
worker_urls[min_count_id].clone() // Update running queue
} let count = running_queue.get_mut(&selected_url).unwrap();
*count += 1;
// Update processed queue
let mut locked_processed_queue = processed_queue.lock().unwrap();
let count = locked_processed_queue.get_mut(&selected_url).unwrap();
*count += 1;
// Update tree with the new request
tree.insert(&text, &selected_url);
selected_url
} }
}; };
if let Router::ApproxTree {
url_to_tree,
url_to_count,
..
} = self
{
// Insert input_ids to the tree
let mut locked_url_to_tree = url_to_tree.lock().unwrap();
let selected_tree = locked_url_to_tree.get_mut(&worker_url).unwrap();
selected_tree.insert(&input_ids[..]);
let mut locked_url_to_count = url_to_count.lock().unwrap();
let count = locked_url_to_count.get_mut(&worker_url).unwrap();
*count += 1;
}
// Check if client requested streaming
let is_stream = serde_json::from_slice::<serde_json::Value>(&body) let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false)) .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
.unwrap_or(false); .unwrap_or(false);
let res = match client let res = match client
.post(format!("{}/generate", worker_url)) .post(format!("{}/generate", worker_url.clone()))
.header( .header(
"Content-Type", "Content-Type",
req.headers() req.headers()
...@@ -216,23 +296,53 @@ impl Router { ...@@ -216,23 +296,53 @@ impl Router {
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
if !is_stream { if !is_stream {
// TODO: do the correction on the tree based on the cached input_ids // For non-streaming requests, get response first
if let Router::ApproxTree { url_to_count, .. } = self { let response = match res.bytes().await {
let mut locked_url_to_count = url_to_count.lock().unwrap();
let count = locked_url_to_count.get_mut(&worker_url).unwrap();
*count -= 1;
}
match res.bytes().await {
Ok(body) => HttpResponse::build(status).body(body.to_vec()), Ok(body) => HttpResponse::build(status).body(body.to_vec()),
Err(_) => HttpResponse::InternalServerError().finish(), Err(_) => HttpResponse::InternalServerError().finish(),
};
// Then decrement running queue counter if using CacheAware
if let Router::CacheAware { running_queue, .. } = self {
if let Ok(mut queue) = running_queue.lock() {
if let Some(count) = queue.get_mut(&worker_url) {
*count = count.saturating_sub(1);
}
}
} }
response
} else if let Router::CacheAware { running_queue, .. } = self {
let running_queue = Arc::clone(running_queue);
let worker_url = worker_url.clone();
HttpResponse::build(status)
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
.streaming(
res.bytes_stream()
.map_err(|_| {
actix_web::error::ErrorInternalServerError("Failed to read stream")
})
.inspect(move |bytes| {
let bytes = bytes.as_ref().unwrap();
if bytes
.as_ref()
.windows(12)
.any(|window| window == b"data: [DONE]")
{
let mut locked_queue = running_queue.lock().unwrap();
let count = locked_queue.get_mut(&worker_url).unwrap();
*count = count.saturating_sub(1);
// print
// println!("streaming is done!!")
}
}),
)
} else { } else {
// TODO: do the correction on the tree based on the cached input_ids. The streaming might be tricker to handle
HttpResponse::build(status) HttpResponse::build(status)
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
.streaming(res.bytes_stream().map_err(|_| { .streaming(res.bytes_stream().map_err(|_| {
actix_web::error::ErrorInternalServerError("Failed to read string") actix_web::error::ErrorInternalServerError("Failed to read stream")
})) }))
} }
} }
......
...@@ -76,6 +76,7 @@ pub async fn startup( ...@@ -76,6 +76,7 @@ pub async fn startup(
) -> std::io::Result<()> { ) -> std::io::Result<()> {
println!("Starting server on {}:{}", host, port); println!("Starting server on {}:{}", host, port);
println!("Worker URLs: {:?}", worker_urls); println!("Worker URLs: {:?}", worker_urls);
println!("Policy Config: {:?}", policy_config);
// Create client once with configuration // Create client once with configuration
let client = reqwest::Client::builder() let client = reqwest::Client::builder()
......
This diff is collapsed.
use sglang_router_rs::tree::RadixTree;
#[test]
fn test_new_tree() {
    // A freshly constructed tree starts completely empty.
    let t = RadixTree::new();
    assert!(t.root.children.is_empty());
    assert!(t.root.ids.is_empty());
    assert_eq!(t.root.count, 0);
}
#[test]
fn test_single_insertion() {
    // One insertion creates a single root child carrying the whole sequence.
    let mut t = RadixTree::new();
    t.insert(&[1, 2, 3]);

    assert_eq!(t.root.children.len(), 1);
    assert_eq!(t.root.count, 1);

    let child = &t.root.children[&1];
    assert_eq!(child.ids, vec![1, 2, 3]);
    assert_eq!(child.count, 1);
}
#[test]
fn test_multiple_insertions_no_split() {
    // Sequences sharing no prefix become independent siblings under the root.
    let mut t = RadixTree::new();
    t.insert(&[1, 2, 3]);
    t.insert(&[4, 5, 6]);

    assert_eq!(t.root.children.len(), 2);
    assert_eq!(t.root.count, 2);
    assert_eq!(t.root.children[&1].ids, vec![1, 2, 3]);
    assert_eq!(t.root.children[&4].ids, vec![4, 5, 6]);
}
#[test]
fn test_insertion_with_split() {
    // A shared prefix forces the existing edge to split at the divergence point.
    let mut t = RadixTree::new();
    t.insert(&[1, 2, 3, 4]);
    t.insert(&[1, 2, 5, 6]);

    assert_eq!(t.root.count, 2);
    assert_eq!(t.root.children.len(), 1);

    // The common prefix [1, 2] is now its own node with two children.
    let shared = &t.root.children[&1];
    assert_eq!(shared.ids, vec![1, 2]);
    assert_eq!(shared.children.len(), 2);
    assert_eq!(shared.children[&3].ids, vec![3, 4]);
    assert_eq!(shared.children[&5].ids, vec![5, 6]);
}
#[test]
fn test_prefix_match_exact() {
    // An inserted sequence matches itself in full.
    let mut t = RadixTree::new();
    t.insert(&[1, 2, 3, 4]);
    assert_eq!(t.prefix_match(&[1, 2, 3, 4]), &[1, 2, 3, 4]);
}
#[test]
fn test_prefix_match_partial() {
    // Matching stops at the first diverging element and returns the common prefix.
    let mut t = RadixTree::new();
    t.insert(&[1, 2, 3, 4]);

    assert_eq!(t.prefix_match(&[1, 5]), &[1]);
    assert_eq!(t.prefix_match(&[1, 2, 5]), &[1, 2]);
    assert_eq!(t.prefix_match(&[1, 2, 3, 5]), &[1, 2, 3]);
}
#[test]
fn test_prefix_match_no_match() {
    // A query sharing nothing with the stored data yields the empty slice.
    let mut t = RadixTree::new();
    t.insert(&[1, 2, 3, 4]);

    let nothing: &[u32] = &[];
    assert_eq!(t.prefix_match(&[5, 6, 7]), nothing);
}
#[test]
fn test_delete_leaf() {
    // Removing the only stored sequence leaves the tree empty again.
    let mut t = RadixTree::new();
    t.insert(&[1, 2, 3]);
    t.delete(&[1, 2, 3]);

    assert_eq!(t.root.children.len(), 0);
    assert_eq!(t.root.count, 0);
}
#[test]
fn test_delete_with_siblings() {
    // Deleting one branch must keep its sibling reachable under the shared prefix.
    let mut t = RadixTree::new();
    t.insert(&[1, 2, 3]);
    t.insert(&[1, 2, 4]);

    t.delete(&[1, 2, 3]);

    assert_eq!(t.root.count, 1);
    assert_eq!(t.root.children[&1].children[&4].ids, vec![4]);
}
#[test]
fn test_multiple_operations() {
    // End-to-end exercise: several inserts, full matches, then a delete.
    let mut t = RadixTree::new();
    t.insert(&[1, 2, 3]);
    t.insert(&[1, 2, 4]);
    t.insert(&[1, 5, 6]);

    // Every inserted sequence is fully matchable.
    assert_eq!(t.root.count, 3);
    assert_eq!(t.prefix_match(&[1, 2, 3]), &[1, 2, 3]);
    assert_eq!(t.prefix_match(&[1, 2, 4]), &[1, 2, 4]);
    assert_eq!(t.prefix_match(&[1, 5, 6]), &[1, 5, 6]);

    // After deleting [1, 2, 3], only its shared prefix [1, 2] still matches.
    t.delete(&[1, 2, 3]);
    assert_eq!(t.root.count, 2);
    assert_eq!(t.prefix_match(&[1, 2, 3]), &[1, 2]);
}
#[test]
#[should_panic(expected = "No match found")]
fn test_delete_nonexistent() {
    // Deleting a sequence that was never inserted is a contract violation: panic.
    let mut t = RadixTree::new();
    t.insert(&[1, 2, 3]);
    t.delete(&[4, 5, 6]);
}
#[test]
fn test_empty_input() {
    // Empty sequences are valid input for insert, match, and delete alike.
    let mut t = RadixTree::new();
    let nothing: &[u32] = &[];

    t.insert(nothing);
    assert_eq!(t.prefix_match(nothing), nothing);
    t.delete(nothing); // must not panic
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment