Unverified commit 4b0a1c93, authored by Byron Hsu, committed by GitHub
Browse files

Replace prob based with threshold based load balancing (#2170)

parent 8e1adb84
...@@ -25,6 +25,7 @@ import warnings ...@@ -25,6 +25,7 @@ import warnings
from argparse import ArgumentParser from argparse import ArgumentParser
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
import aiohttp import aiohttp
...@@ -693,6 +694,19 @@ def gen_prompt(tokenizer, token_num): ...@@ -693,6 +694,19 @@ def gen_prompt(tokenizer, token_num):
return tokenizer.decode(selected_tokens) return tokenizer.decode(selected_tokens)
def get_gen_prefix_cache_path(args, tokenizer):
    """Return the pickle cache path for a generated-shared-prefix dataset.

    The path points under ``~/.cache/sglang/benchmark``. This function only
    computes the path; it does not create the directory or the file.

    Args:
        args: Namespace providing ``gen_num_groups``,
            ``gen_prompts_per_group``, ``gen_system_prompt_len``,
            ``gen_question_len`` and ``gen_output_len``.
        tokenizer: Tokenizer used to generate the prompts. Its class name is
            always part of the key; if it exposes a non-empty
            ``name_or_path`` attribute, a short hash of that value is added
            as well.

    Returns:
        A ``pathlib.Path`` for the cache file.
    """
    import hashlib

    cache_dir = Path.home() / ".cache" / "sglang" / "benchmark"

    # The class name alone cannot tell two checkpoints apart when they share
    # a tokenizer class, which would make them reuse each other's cached
    # prompts. Hash the identifier rather than embedding it, since it may
    # contain path separators or be arbitrarily long.
    tokenizer_tag = tokenizer.__class__.__name__
    name_or_path = getattr(tokenizer, "name_or_path", None)
    if name_or_path:
        digest = hashlib.sha256(str(name_or_path).encode("utf-8")).hexdigest()
        tokenizer_tag = f"{tokenizer_tag}_{digest[:12]}"

    # Create a unique cache filename based on the generation parameters
    cache_key = (
        f"gen_prefix_{args.gen_num_groups}_{args.gen_prompts_per_group}_"
        f"{args.gen_system_prompt_len}_{args.gen_question_len}_{args.gen_output_len}_"
        f"{tokenizer_tag}.pkl"
    )
    return cache_dir / cache_key
def sample_generated_shared_prefix_requests( def sample_generated_shared_prefix_requests(
num_groups: int, num_groups: int,
prompts_per_group: int, prompts_per_group: int,
...@@ -701,12 +715,17 @@ def sample_generated_shared_prefix_requests( ...@@ -701,12 +715,17 @@ def sample_generated_shared_prefix_requests(
output_len: int, output_len: int,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]: ) -> List[Tuple[str, int, int]]:
if args.generated_input_path and os.path.exists(args.generated_input_path): """Generate benchmark requests with shared system prompts using random tokens and caching."""
print(f"\nloading generated input data from {args.generated_input_path}") cache_path = get_gen_prefix_cache_path(args, tokenizer)
with open(args.generated_input_path, "rb") as f:
# Try to load from cache first
if cache_path.exists():
print(f"\nLoading cached generated input data from {cache_path}")
with open(cache_path, "rb") as f:
return pickle.load(f) return pickle.load(f)
"""Generate benchmark requests with shared system prompts using random tokens.""" print("\nGenerating new input data...")
# Generate system prompts for each group # Generate system prompts for each group
system_prompts = [] system_prompts = []
for _ in range(num_groups): for _ in range(num_groups):
...@@ -719,9 +738,6 @@ def sample_generated_shared_prefix_requests( ...@@ -719,9 +738,6 @@ def sample_generated_shared_prefix_requests(
question = gen_prompt(tokenizer, question_len) question = gen_prompt(tokenizer, question_len)
questions.append(question) questions.append(question)
# Shuffle questions
random.shuffle(questions)
# Combine system prompts with questions # Combine system prompts with questions
input_requests = [] input_requests = []
total_input_tokens = 0 total_input_tokens = 0
...@@ -729,7 +745,9 @@ def sample_generated_shared_prefix_requests( ...@@ -729,7 +745,9 @@ def sample_generated_shared_prefix_requests(
for group_idx in tqdm(range(num_groups), desc="Generating system prompt"): for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
system_prompt = system_prompts[group_idx] system_prompt = system_prompts[group_idx]
for prompt_idx in tqdm(range(prompts_per_group), desc="Generating questions"): for prompt_idx in tqdm(
range(prompts_per_group), desc="Generating questions", leave=False
):
question = questions[group_idx * prompts_per_group + prompt_idx] question = questions[group_idx * prompts_per_group + prompt_idx]
full_prompt = f"{system_prompt}\n\n{question}" full_prompt = f"{system_prompt}\n\n{question}"
prompt_len = len(tokenizer.encode(full_prompt)) prompt_len = len(tokenizer.encode(full_prompt))
...@@ -738,6 +756,10 @@ def sample_generated_shared_prefix_requests( ...@@ -738,6 +756,10 @@ def sample_generated_shared_prefix_requests(
total_input_tokens += prompt_len total_input_tokens += prompt_len
total_output_tokens += output_len total_output_tokens += output_len
# Shuffle the combined requests so prompts from different groups are interleaved
random.shuffle(input_requests)
# Print statistics
print(f"\nGenerated shared prefix dataset statistics:") print(f"\nGenerated shared prefix dataset statistics:")
print(f"Number of groups: {num_groups}") print(f"Number of groups: {num_groups}")
print(f"Prompts per group: {prompts_per_group}") print(f"Prompts per group: {prompts_per_group}")
...@@ -750,11 +772,12 @@ def sample_generated_shared_prefix_requests( ...@@ -750,11 +772,12 @@ def sample_generated_shared_prefix_requests(
print( print(
f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n" f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
) )
if args.generated_input_save_path:
print(f"Saving generated input data to {args.generated_input_save_path}") # Save to cache
os.makedirs(os.path.dirname(args.generated_input_save_path), exist_ok=True) cache_path.parent.mkdir(parents=True, exist_ok=True)
with open(args.generated_input_save_path, "wb") as f: print(f"Caching generated input data to {cache_path}")
pickle.dump(input_requests, f) with open(cache_path, "wb") as f:
pickle.dump(input_requests, f)
return input_requests return input_requests
...@@ -1422,16 +1445,6 @@ if __name__ == "__main__": ...@@ -1422,16 +1445,6 @@ if __name__ == "__main__":
default=256, default=256,
help="Target length in tokens for outputs in generated-shared-prefix dataset", help="Target length in tokens for outputs in generated-shared-prefix dataset",
) )
parser.add_argument(
"--generated-input-save-path",
type=str,
help="Path to save generated input data",
)
parser.add_argument(
"--generated-input-path",
type=str,
help="Path to load previously generated input data",
)
parser.add_argument( parser.add_argument(
"--profile", "--profile",
action="store_true", action="store_true",
......
...@@ -20,33 +20,35 @@ $ python -m sglang_router.launch_server --model-path meta-llama/Meta-Llama-3.1-8 ...@@ -20,33 +20,35 @@ $ python -m sglang_router.launch_server --model-path meta-llama/Meta-Llama-3.1-8
``` ```
### 2. Launch only router ### 2. Launch only router
This is useful if you for multi node DP. You can launch workers on different nodes, then connect the router to them. This is useful for multi-node DP. You can launch workers on different nodes, then connect the router to them.
```bash ```bash
$ python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 $ python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
$ python -m sglang_router.launch_router --help $ python -m sglang_router.launch_router --help
usage: launch_router.py [-h] [--host HOST] [--port PORT] [--worker-urls WORKER_URLS [WORKER_URLS ...]] usage: launch_router.py [-h] [--host HOST] [--port PORT] [--worker-urls WORKER_URLS [WORKER_URLS ...]]
[--policy {random,round_robin,cache_aware}] [--cache-threshold CACHE_THRESHOLD] [--policy {random,round_robin,cache_aware}] [--cache-threshold CACHE_THRESHOLD]
[--cache-routing-prob CACHE_ROUTING_PROB] [--eviction-interval EVICTION_INTERVAL] [--balance-abs-threshold BALANCE_ABS_THRESHOLD] [--balance-rel-threshold BALANCE_REL_THRESHOLD]
[--max-tree-size MAX_TREE_SIZE] [--eviction-interval EVICTION_INTERVAL] [--max-tree-size MAX_TREE_SIZE]
options: options:
-h, --help show this help message and exit -h, --help show this help message and exit
--host HOST Host address to bind the router server (default: 127.0.0.1) --host HOST Host address to bind the router server (default: 127.0.0.1)
--port PORT Port number to bind the router server (default: 30000) --port PORT Port number to bind the router server (default: 30000)
--worker-urls WORKER_URLS [WORKER_URLS ...] --worker-urls WORKER_URLS [WORKER_URLS ...]
List of worker URLs (e.g., http://worker1:8000 http://worker2:8000) (default: None) List of worker URLs (e.g., http://worker1:8000 http://worker2:8000) (default: None)
--policy {random,round_robin,cache_aware} --policy {random,round_robin,cache_aware}
Load balancing policy to use (default: cache_aware) Load balancing policy to use (default: cache_aware)
--cache-threshold CACHE_THRESHOLD --cache-threshold CACHE_THRESHOLD
Cache threshold (0.0-1.0) for cache-aware routing (default: 0.5) Cache threshold (0.0-1.0) for cache-aware routing (default: 0.5)
--cache-routing-prob CACHE_ROUTING_PROB --balance-abs-threshold BALANCE_ABS_THRESHOLD
Probability of using cache-aware routing (0.0-1.0) (default: 1.0) Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold (default: 32)
--balance-rel-threshold BALANCE_REL_THRESHOLD
Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold (default: 1.0001)
--eviction-interval EVICTION_INTERVAL --eviction-interval EVICTION_INTERVAL
Interval in seconds between cache eviction operations (default: 60) Interval in seconds between cache eviction operations (default: 60)
--max-tree-size MAX_TREE_SIZE --max-tree-size MAX_TREE_SIZE
Maximum size of the approximation tree for cache-aware routing (default: 16777216) Maximum size of the approximation tree for cache-aware routing (default: 16777216)
``` ```
## Strategy ## Strategy
...@@ -56,7 +58,15 @@ options: ...@@ -56,7 +58,15 @@ options:
This router combines two strategies to optimize both cache utilization and request distribution: This router combines two strategies to optimize both cache utilization and request distribution:
1. Cache-Aware Routing (Approximate Tree) 1. Cache-Aware Routing (Approximate Tree)
2. Load-Balancing Routing (Shortest Queue) 2. Load-Balancing Routing (Shortest Queue with Balance Thresholds)
The router dynamically switches between these strategies based on load conditions:
- Uses load balancing when the system is imbalanced
- Uses cache-aware routing when the system is balanced
A system is considered imbalanced if both conditions are met:
1. (max_load - min_load) > balance_abs_threshold
2. max_load > balance_rel_threshold * min_load
#### 1. Cache-Aware Routing (Approximate Tree) #### 1. Cache-Aware Routing (Approximate Tree)
This strategy maintains an approximate radix tree for each worker based on request history, This strategy maintains an approximate radix tree for each worker based on request history,
...@@ -74,27 +84,32 @@ Process: ...@@ -74,27 +84,32 @@ Process:
#### 2. Load-Balancing (Shortest Queue) #### 2. Load-Balancing (Shortest Queue)
This strategy tracks pending request counts per worker and routes new requests This strategy tracks pending request counts per worker and routes new requests
to the least busy worker for optimal load distribution. to the least busy worker when the system is detected to be imbalanced. This helps
maintain optimal load distribution across workers.
### Configuration Parameters ### Configuration Parameters
1. `cache_routing_prob`: (float, 0.0 to 1.0) 1. `cache_threshold`: (float, 0.0 to 1.0, default: 0.5)
- 0.0: Exclusively use load balancing
- 1.0: Exclusively use cache-aware routing
- Between 0-1: Probability of using cache-aware routing vs load balancing
2. `cache_threshold`: (float, 0.0 to 1.0)
- Minimum prefix match ratio to use highest-match routing - Minimum prefix match ratio to use highest-match routing
- Below this threshold, routes to worker with most available cache space - Below this threshold, routes to worker with most available cache space
3. `eviction_interval_secs`: (integer) 2. `balance_abs_threshold`: (integer, default: 32)
- Interval between LRU eviction cycles for the approximate trees - Absolute difference threshold for load imbalance detection
- System is potentially imbalanced if (max_load - min_load) > abs_threshold
4. `max_tree_size`: (integer) 3. `balance_rel_threshold`: (float, default: 1.0001)
- Relative ratio threshold for load imbalance detection
- System is potentially imbalanced if max_load > min_load * rel_threshold
- Used in conjunction with abs_threshold to determine final imbalance state
4. `eviction_interval`: (integer, default: 60)
- Interval in seconds between LRU eviction cycles for the approximate trees
- Background thread periodically evicts least recently used nodes to maintain tree size
5. `max_tree_size`: (integer, default: 16777216)
- Maximum nodes per tree - Maximum nodes per tree
- When exceeded, LRU leaf nodes are evicted during the next eviction cycle - When exceeded, LRU leaf nodes are evicted during the next eviction cycle
## Development ## Development
- Rust and Cargo installed - Rust and Cargo installed
......
...@@ -17,7 +17,8 @@ class RouterArgs: ...@@ -17,7 +17,8 @@ class RouterArgs:
# Routing policy # Routing policy
policy: str = "cache_aware" policy: str = "cache_aware"
cache_threshold: float = 0.5 cache_threshold: float = 0.5
cache_routing_prob: float = 1.0 balance_abs_threshold: int = 32
balance_rel_threshold: float = 1.0001
eviction_interval: int = 60 eviction_interval: int = 60
max_tree_size: int = 2**24 max_tree_size: int = 2**24
...@@ -74,10 +75,16 @@ class RouterArgs: ...@@ -74,10 +75,16 @@ class RouterArgs:
help="Cache threshold (0.0-1.0) for cache-aware routing", help="Cache threshold (0.0-1.0) for cache-aware routing",
) )
parser.add_argument( parser.add_argument(
f"--{prefix}cache-routing-prob", f"--{prefix}balance-abs-threshold",
type=int,
default=RouterArgs.balance_abs_threshold,
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
)
parser.add_argument(
f"--{prefix}balance-rel-threshold",
type=float, type=float,
default=RouterArgs.cache_routing_prob, default=RouterArgs.balance_rel_threshold,
help="Probability of using cache-aware routing (0.0-1.0)", help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
) )
parser.add_argument( parser.add_argument(
f"--{prefix}eviction-interval", f"--{prefix}eviction-interval",
...@@ -110,7 +117,8 @@ class RouterArgs: ...@@ -110,7 +117,8 @@ class RouterArgs:
port=args.port, port=args.port,
policy=getattr(args, f"{prefix}policy"), policy=getattr(args, f"{prefix}policy"),
cache_threshold=getattr(args, f"{prefix}cache_threshold"), cache_threshold=getattr(args, f"{prefix}cache_threshold"),
cache_routing_prob=getattr(args, f"{prefix}cache_routing_prob"), balance_abs_threshold=getattr(args, f"{prefix}balance_abs_threshold"),
balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"),
eviction_interval=getattr(args, f"{prefix}eviction_interval"), eviction_interval=getattr(args, f"{prefix}eviction_interval"),
max_tree_size=getattr(args, f"{prefix}max_tree_size"), max_tree_size=getattr(args, f"{prefix}max_tree_size"),
) )
...@@ -150,7 +158,8 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: ...@@ -150,7 +158,8 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
host=router_args.host, host=router_args.host,
port=router_args.port, port=router_args.port,
cache_threshold=router_args.cache_threshold, cache_threshold=router_args.cache_threshold,
cache_routing_prob=router_args.cache_routing_prob, balance_abs_threshold=router_args.balance_abs_threshold,
balance_rel_threshold=router_args.balance_rel_threshold,
eviction_interval_secs=router_args.eviction_interval, eviction_interval_secs=router_args.eviction_interval,
max_tree_size=router_args.max_tree_size, max_tree_size=router_args.max_tree_size,
) )
...@@ -182,7 +191,7 @@ multi-node setups or when you want to start workers and router separately. ...@@ -182,7 +191,7 @@ multi-node setups or when you want to start workers and router separately.
Examples: Examples:
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 --cache-threshold 0.7 --cache-routing-prob 0.5 python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 --cache-threshold 0.7 --balance-abs-threshold 64 --balance-rel-threshold 1.2
""", """,
formatter_class=CustomHelpFormatter, formatter_class=CustomHelpFormatter,
......
...@@ -14,15 +14,16 @@ class Router: ...@@ -14,15 +14,16 @@ class Router:
policy: Load balancing policy to use. Options: policy: Load balancing policy to use. Options:
- PolicyType.Random: Randomly select workers - PolicyType.Random: Randomly select workers
- PolicyType.RoundRobin: Distribute requests in round-robin fashion - PolicyType.RoundRobin: Distribute requests in round-robin fashion
- PolicyType.CacheAware: Distribute requests in cache-aware fashion - PolicyType.CacheAware: Distribute requests based on cache state and load balance
host: Host address to bind the router server. Default: '127.0.0.1' host: Host address to bind the router server. Default: '127.0.0.1'
port: Port number to bind the router server. Default: 3001 port: Port number to bind the router server. Default: 3001
cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker
if the match rate exceeds threshold, otherwise routes to the worker with the smallest if the match rate exceeds threshold, otherwise routes to the worker with the smallest
tree. Default: 0.5 tree. Default: 0.5
cache_routing_prob: Probability of using cache-aware routing (0.0-1.0). Default 1.0 for balance_abs_threshold: Load balancing is triggered when (max_load - min_load) > abs_threshold
full cache-aware routing, suitable for perfectly divided prefix workloads. For uneven AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 32
workloads, use a lower value to better distribute requests balance_rel_threshold: Load balancing is triggered when (max_load - min_load) > abs_threshold
AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
routing. Default: 60 routing. Default: 60
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24 max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
...@@ -35,7 +36,8 @@ class Router: ...@@ -35,7 +36,8 @@ class Router:
host: str = "127.0.0.1", host: str = "127.0.0.1",
port: int = 3001, port: int = 3001,
cache_threshold: float = 0.50, cache_threshold: float = 0.50,
cache_routing_prob: float = 1.0, balance_abs_threshold: int = 32,
balance_rel_threshold: float = 1.0001,
eviction_interval_secs: int = 60, eviction_interval_secs: int = 60,
max_tree_size: int = 2**24, max_tree_size: int = 2**24,
): ):
...@@ -45,7 +47,8 @@ class Router: ...@@ -45,7 +47,8 @@ class Router:
host=host, host=host,
port=port, port=port,
cache_threshold=cache_threshold, cache_threshold=cache_threshold,
cache_routing_prob=cache_routing_prob, balance_abs_threshold=balance_abs_threshold,
balance_rel_threshold=balance_rel_threshold,
eviction_interval_secs=eviction_interval_secs, eviction_interval_secs=eviction_interval_secs,
max_tree_size=max_tree_size, max_tree_size=max_tree_size,
) )
......
...@@ -18,7 +18,8 @@ struct Router { ...@@ -18,7 +18,8 @@ struct Router {
worker_urls: Vec<String>, worker_urls: Vec<String>,
policy: PolicyType, policy: PolicyType,
cache_threshold: f32, cache_threshold: f32,
cache_routing_prob: f32, balance_abs_threshold: usize,
balance_rel_threshold: f32,
eviction_interval_secs: u64, eviction_interval_secs: u64,
max_tree_size: usize, max_tree_size: usize,
} }
...@@ -32,7 +33,8 @@ impl Router { ...@@ -32,7 +33,8 @@ impl Router {
host = String::from("127.0.0.1"), host = String::from("127.0.0.1"),
port = 3001, port = 3001,
cache_threshold = 0.50, cache_threshold = 0.50,
cache_routing_prob = 1.0, balance_abs_threshold = 32,
balance_rel_threshold = 1.0001,
eviction_interval_secs = 60, eviction_interval_secs = 60,
max_tree_size = 2usize.pow(24) max_tree_size = 2usize.pow(24)
))] ))]
...@@ -42,7 +44,8 @@ impl Router { ...@@ -42,7 +44,8 @@ impl Router {
host: String, host: String,
port: u16, port: u16,
cache_threshold: f32, cache_threshold: f32,
cache_routing_prob: f32, balance_abs_threshold: usize,
balance_rel_threshold: f32,
eviction_interval_secs: u64, eviction_interval_secs: u64,
max_tree_size: usize, max_tree_size: usize,
) -> PyResult<Self> { ) -> PyResult<Self> {
...@@ -52,7 +55,8 @@ impl Router { ...@@ -52,7 +55,8 @@ impl Router {
worker_urls, worker_urls,
policy, policy,
cache_threshold, cache_threshold,
cache_routing_prob, balance_abs_threshold,
balance_rel_threshold,
eviction_interval_secs, eviction_interval_secs,
max_tree_size, max_tree_size,
}) })
...@@ -68,7 +72,8 @@ impl Router { ...@@ -68,7 +72,8 @@ impl Router {
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig, PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig,
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig { PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
cache_threshold: self.cache_threshold, cache_threshold: self.cache_threshold,
cache_routing_prob: self.cache_routing_prob, balance_abs_threshold: self.balance_abs_threshold,
balance_rel_threshold: self.balance_rel_threshold,
eviction_interval_secs: self.eviction_interval_secs, eviction_interval_secs: self.eviction_interval_secs,
max_tree_size: self.max_tree_size, max_tree_size: self.max_tree_size,
}, },
......
// src/main.rs
use clap::Parser; use clap::Parser;
use clap::ValueEnum; use clap::ValueEnum;
...@@ -42,7 +41,7 @@ struct Args { ...@@ -42,7 +41,7 @@ struct Args {
help = "Load balancing policy to use for request distribution:\n\ help = "Load balancing policy to use for request distribution:\n\
- random: Randomly select workers\n\ - random: Randomly select workers\n\
- round_robin: Distribute requests in round-robin fashion\n\ - round_robin: Distribute requests in round-robin fashion\n\
- cache_aware: Distribute requests in cache-aware fashion\n" - cache_aware: Distribute requests based on cache state and load balance\n"
)] )]
policy: PolicyType, policy: PolicyType,
...@@ -57,12 +56,21 @@ struct Args { ...@@ -57,12 +56,21 @@ struct Args {
#[arg( #[arg(
long, long,
default_value_t = 1.0, default_value_t = 32,
requires = "policy", requires = "policy",
required_if_eq("policy", "cache_aware"), required_if_eq("policy", "cache_aware"),
help = "Probability of using cache-aware routing (0.0-1.0). Default 1.0 for full cache-aware routing, suitable for perfectly divided prefix workloads. For uneven workloads, use a lower value to better distribute requests" help = "Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 32"
)] )]
cache_routing_prob: f32, balance_abs_threshold: usize,
#[arg(
long,
default_value_t = 1.0001,
requires = "policy",
required_if_eq("policy", "cache_aware"),
help = "Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001"
)]
balance_rel_threshold: f32,
#[arg( #[arg(
long, long,
...@@ -90,7 +98,8 @@ impl Args { ...@@ -90,7 +98,8 @@ impl Args {
PolicyType::RoundRobin => PolicyConfig::RoundRobinConfig, PolicyType::RoundRobin => PolicyConfig::RoundRobinConfig,
PolicyType::CacheAware => PolicyConfig::CacheAwareConfig { PolicyType::CacheAware => PolicyConfig::CacheAwareConfig {
cache_threshold: self.cache_threshold, cache_threshold: self.cache_threshold,
cache_routing_prob: self.cache_routing_prob, balance_abs_threshold: self.balance_abs_threshold,
balance_rel_threshold: self.balance_rel_threshold,
eviction_interval_secs: self.eviction_interval_secs, eviction_interval_secs: self.eviction_interval_secs,
max_tree_size: self.max_tree_size, max_tree_size: self.max_tree_size,
}, },
......
...@@ -23,65 +23,73 @@ pub enum Router { ...@@ -23,65 +23,73 @@ pub enum Router {
}, },
CacheAware { CacheAware {
/* /*
Cache-Aware Load Balancing Router Cache-Aware Load Balancing Router
This router combines two strategies to optimize both cache utilization and request distribution: This router combines two strategies to optimize both cache utilization and request distribution:
1. Cache-Aware Routing (Approximate Tree) 1. Cache-Aware Routing (Approximate Tree)
2. Load Balancing (Shortest Queue) 2. Load Balancing (Shortest Queue with Balance Thresholds)
For each incoming request, the router chooses between these strategies: The router dynamically switches between these strategies based on load conditions:
- With probability P: Uses cache-aware routing - Uses load balancing when the system is imbalanced
- With probability (1-P): Uses load balancing - Uses cache-aware routing when the system is balanced
where P is configured via `cache_routing_prob`
A system is considered imbalanced if both conditions are met:
Strategy Details: 1. (max - min) > abs_threshold
2. max > rel_threshold * min
1. Cache-Aware Routing (Approximate Tree)
------------------------------------------- Strategy Details:
This strategy maintains an approximate radix tree for each worker based on request history,
eliminating the need for direct cache state queries. The tree stores raw text characters 1. Cache-Aware Routing (Approximate Tree)
instead of token IDs to avoid tokenization overhead. -------------------------------------------
This strategy maintains an approximate radix tree for each worker based on request history,
Process: eliminating the need for direct cache state queries. The tree stores raw text characters
a. For each request, find the worker with the highest prefix match instead of token IDs to avoid tokenization overhead.
b. If match rate > cache_threshold:
Route to the worker with highest match (likely has relevant data cached) Process:
c. If match rate ≤ cache_threshold: a. For each request, find the worker with the highest prefix match
Route to the worker with smallest tree size (most available cache capacity) b. If match rate > cache_threshold:
d. Background maintenance: Route to the worker with highest match (likely has relevant data cached)
Periodically evict least recently used leaf nodes to prevent memory overflow c. If match rate ≤ cache_threshold:
Route to the worker with smallest tree size (most available cache capacity)
2. Load Balancing (Shortest Queue) d. Background maintenance:
------------------------------------------- Periodically evict least recently used leaf nodes to prevent memory overflow
This strategy tracks pending request counts per worker and routes new requests
to the least busy worker for optimal load distribution. 2. Load Balancing (Shortest Queue)
-------------------------------------------
Configuration Parameters: This strategy tracks pending request counts per worker and routes new requests
------------------------ to the least busy worker when the system is detected to be imbalanced.
1. cache_routing_prob: (float, 0.0 to 1.0)
- 0.0: Exclusively use load balancing Configuration Parameters:
- 1.0: Exclusively use cache-aware routing ------------------------
- Between 0-1: Probability of using cache-aware routing vs load balancing 1. cache_threshold: (float, 0.0 to 1.0)
Minimum prefix match ratio to use highest-match routing.
2. cache_threshold: (float, 0.0 to 1.0) Below this threshold, routes to worker with most available cache space.
Minimum prefix match ratio to use highest-match routing.
Below this threshold, routes to worker with most available cache space. 2. balance_abs_threshold: (integer)
Absolute difference threshold for load imbalance detection.
3. eviction_interval_secs: (integer) System is potentially imbalanced if (max_load - min_load) > abs_threshold
Interval between LRU eviction cycles for the approximate trees.
3. balance_rel_threshold: (float)
4. max_tree_size: (integer) Relative ratio threshold for load imbalance detection.
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted System is potentially imbalanced if max_load > min_load * rel_threshold
during the next eviction cycle. Used in conjunction with abs_threshold to determine final imbalance state.
4. eviction_interval_secs: (integer)
Interval between LRU eviction cycles for the approximate trees.
5. max_tree_size: (integer)
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
during the next eviction cycle.
*/ */
worker_urls: Vec<String>, worker_urls: Vec<String>,
tree: Arc<Mutex<Tree>>, tree: Arc<Mutex<Tree>>,
running_queue: Arc<Mutex<HashMap<String, usize>>>, running_queue: Arc<Mutex<HashMap<String, usize>>>,
processed_queue: Arc<Mutex<HashMap<String, usize>>>, processed_queue: Arc<Mutex<HashMap<String, usize>>>,
cache_threshold: f32, cache_threshold: f32,
cache_routing_prob: f32, balance_abs_threshold: usize,
_eviction_thread: Option<thread::JoinHandle<()>>, // Store thread handle balance_rel_threshold: f32,
_eviction_thread: Option<thread::JoinHandle<()>>,
}, },
} }
...@@ -91,7 +99,8 @@ pub enum PolicyConfig { ...@@ -91,7 +99,8 @@ pub enum PolicyConfig {
RoundRobinConfig, RoundRobinConfig,
CacheAwareConfig { CacheAwareConfig {
cache_threshold: f32, cache_threshold: f32,
cache_routing_prob: f32, balance_abs_threshold: usize,
balance_rel_threshold: f32,
eviction_interval_secs: u64, eviction_interval_secs: u64,
max_tree_size: usize, max_tree_size: usize,
}, },
...@@ -128,7 +137,8 @@ impl Router { ...@@ -128,7 +137,8 @@ impl Router {
}, },
PolicyConfig::CacheAwareConfig { PolicyConfig::CacheAwareConfig {
cache_threshold, cache_threshold,
cache_routing_prob, balance_abs_threshold,
balance_rel_threshold,
eviction_interval_secs, eviction_interval_secs,
max_tree_size, max_tree_size,
} => { } => {
...@@ -149,6 +159,7 @@ impl Router { ...@@ -149,6 +159,7 @@ impl Router {
// Create background eviction thread // Create background eviction thread
let tree_clone = Arc::clone(&tree); let tree_clone = Arc::clone(&tree);
let processed_queue_clone = Arc::clone(&processed_queue); let processed_queue_clone = Arc::clone(&processed_queue);
let running_queue_clone = Arc::clone(&running_queue);
let eviction_thread = thread::spawn(move || { let eviction_thread = thread::spawn(move || {
loop { loop {
// Sleep for the specified interval // Sleep for the specified interval
...@@ -161,6 +172,10 @@ impl Router { ...@@ -161,6 +172,10 @@ impl Router {
// Print the process queue // Print the process queue
let locked_processed_queue = processed_queue_clone.lock().unwrap(); let locked_processed_queue = processed_queue_clone.lock().unwrap();
println!("Processed Queue: {:?}", locked_processed_queue); println!("Processed Queue: {:?}", locked_processed_queue);
// Print the running queue
let locked_running_queue = running_queue_clone.lock().unwrap();
println!("Running Queue: {:?}", locked_running_queue);
} }
}); });
...@@ -174,7 +189,8 @@ impl Router { ...@@ -174,7 +189,8 @@ impl Router {
running_queue, running_queue,
processed_queue, processed_queue,
cache_threshold, cache_threshold,
cache_routing_prob, balance_abs_threshold,
balance_rel_threshold,
_eviction_thread: Some(eviction_thread), _eviction_thread: Some(eviction_thread),
} }
} }
...@@ -203,8 +219,6 @@ impl Router { ...@@ -203,8 +219,6 @@ impl Router {
route: &str, route: &str,
) -> HttpResponse { ) -> HttpResponse {
let text = get_text_from_request(&body, route); let text = get_text_from_request(&body, route);
// For Debug
// println!("text: {:?}, route: {:?}", text, route);
let worker_url = match self { let worker_url = match self {
Router::RoundRobin { Router::RoundRobin {
...@@ -218,7 +232,6 @@ impl Router { ...@@ -218,7 +232,6 @@ impl Router {
|x| Some((x + 1) % worker_urls.len()), |x| Some((x + 1) % worker_urls.len()),
) )
.unwrap(); .unwrap();
worker_urls[idx].clone() worker_urls[idx].clone()
} }
...@@ -232,19 +245,42 @@ impl Router { ...@@ -232,19 +245,42 @@ impl Router {
running_queue, running_queue,
processed_queue, processed_queue,
cache_threshold, cache_threshold,
cache_routing_prob, balance_abs_threshold,
balance_rel_threshold,
.. ..
} => { } => {
// even though the tree is thread-safe, we still put a lock to ensure the whole op (tree read + queue read + tree write + queue write) is atomic to handle some edge cases (e.g. multiple requests with long prefix entering at the same time) // TODO: delay scheduling if cache hit rate is high because it may cause imbalance. prioritize low hit rate ones
let mut tree = tree.lock().unwrap(); let mut tree = tree.lock().unwrap();
let mut running_queue = running_queue.lock().unwrap(); let mut running_queue = running_queue.lock().unwrap();
// Generate a random float between 0 and 1 for probability check // Get current load statistics
let sampled_p: f32 = rand::random(); let max_load = *running_queue.values().max().unwrap_or(&0);
let min_load = *running_queue.values().min().unwrap_or(&0);
let selected_url = if sampled_p < *cache_routing_prob {
// Cache-aware routing logic // Load is considered imbalanced if:
// 1. (max - min) > abs_threshold AND
// 2. max > rel_threshold * min
let is_imbalanced = max_load.saturating_sub(min_load) > *balance_abs_threshold
&& (max_load as f32) > (min_load as f32 * balance_rel_threshold);
let selected_url = if is_imbalanced {
// Log load balancing trigger and current queue state
println!(
"Load balancing triggered due to workload imbalance:\n\
Max load: {}, Min load: {}\n\
Current running queue: {:?}",
max_load, min_load, running_queue
);
// Use shortest queue routing when load is imbalanced
running_queue
.iter()
.min_by_key(|(_url, &count)| count)
.map(|(url, _)| url.clone())
.unwrap_or_else(|| worker_urls[0].clone())
} else {
// Use cache-aware routing when load is balanced
let (matched_text, matched_worker) = tree.prefix_match(&text); let (matched_text, matched_worker) = tree.prefix_match(&text);
let matched_rate = let matched_rate =
matched_text.chars().count() as f32 / text.chars().count() as f32; matched_text.chars().count() as f32 / text.chars().count() as f32;
...@@ -252,36 +288,18 @@ impl Router { ...@@ -252,36 +288,18 @@ impl Router {
if matched_rate > *cache_threshold { if matched_rate > *cache_threshold {
matched_worker.to_string() matched_worker.to_string()
} else { } else {
// For Debug
// let m_map: HashMap<String, usize> = tree
// .tenant_char_count
// .iter()
// .map(|entry| (entry.key().clone(), *entry.value()))
// .collect();
// println!("map: {:?}, mmap: {:?}", tree.get_tenant_char_count(), m_map);
tree.get_smallest_tenant() tree.get_smallest_tenant()
} }
} else {
// Shortest queue routing logic
running_queue
.iter()
.min_by_key(|(_url, &count)| count)
.map(|(url, _)| url.clone())
.unwrap_or_else(|| worker_urls[0].clone())
}; };
// Update running queue // Update queues and tree
let count = running_queue.get_mut(&selected_url).unwrap(); *running_queue.get_mut(&selected_url).unwrap() += 1;
*count += 1;
// Update processed queue
let mut locked_processed_queue = processed_queue.lock().unwrap();
let count = locked_processed_queue.get_mut(&selected_url).unwrap();
*count += 1;
// Update tree with the new request *processed_queue
.lock()
.unwrap()
.get_mut(&selected_url)
.unwrap() += 1;
tree.insert(&text, &selected_url); tree.insert(&text, &selected_url);
selected_url selected_url
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment