Unverified Commit 30af7dfb authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

[router] add base_gpu_id server args & merged radix tree python reference (#2115)

parent f6f71379
...@@ -156,7 +156,7 @@ class DataParallelController: ...@@ -156,7 +156,7 @@ class DataParallelController:
) )
for tp_rank in tp_rank_range: for tp_rank in tp_rank_range:
reader, writer = mp.Pipe(duplex=False) reader, writer = mp.Pipe(duplex=False)
gpu_id = base_gpu_id + tp_rank % tp_size_per_node gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
proc = mp.Process( proc = mp.Process(
target=run_scheduler_process, target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer), args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
......
...@@ -1380,6 +1380,10 @@ def run_scheduler_process( ...@@ -1380,6 +1380,10 @@ def run_scheduler_process(
dp_rank: Optional[int], dp_rank: Optional[int],
pipe_writer, pipe_writer,
): ):
# [For Router] if env var "DP_RANK" exist, set dp_rank to the value of the env var
if dp_rank is None:
dp_rank = int(os.getenv("DP_RANK", -1))
if dp_rank is None: if dp_rank is None:
configure_logger(server_args, prefix=f" TP{tp_rank}") configure_logger(server_args, prefix=f" TP{tp_rank}")
else: else:
......
...@@ -418,7 +418,7 @@ def launch_engine( ...@@ -418,7 +418,7 @@ def launch_engine(
) )
for tp_rank in tp_rank_range: for tp_rank in tp_rank_range:
reader, writer = mp.Pipe(duplex=False) reader, writer = mp.Pipe(duplex=False)
gpu_id = tp_rank % tp_size_per_node gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node
proc = mp.Process( proc = mp.Process(
target=run_scheduler_process, target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, None, writer), args=(server_args, port_args, gpu_id, tp_rank, None, writer),
......
...@@ -72,6 +72,7 @@ class ServerArgs: ...@@ -72,6 +72,7 @@ class ServerArgs:
constrained_json_whitespace_pattern: Optional[str] = None constrained_json_whitespace_pattern: Optional[str] = None
watchdog_timeout: float = 300 watchdog_timeout: float = 300
download_dir: Optional[str] = None download_dir: Optional[str] = None
base_gpu_id: int = 0
# Logging # Logging
log_level: str = "info" log_level: str = "info"
...@@ -412,6 +413,12 @@ class ServerArgs: ...@@ -412,6 +413,12 @@ class ServerArgs:
default=ServerArgs.download_dir, default=ServerArgs.download_dir,
help="Model download directory.", help="Model download directory.",
) )
parser.add_argument(
"--base-gpu-id",
type=int,
default=ServerArgs.base_gpu_id,
help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
)
# Logging # Logging
parser.add_argument( parser.add_argument(
...@@ -736,6 +743,7 @@ class ServerArgs: ...@@ -736,6 +743,7 @@ class ServerArgs:
and (self.lora_paths is None or self.disable_cuda_graph) and (self.lora_paths is None or self.disable_cuda_graph)
and (self.lora_paths is None or self.disable_radix_cache) and (self.lora_paths is None or self.disable_radix_cache)
), "compatibility of lora and cuda graph and radix attention is in progress" ), "compatibility of lora and cuda graph and radix attention is in progress"
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
if isinstance(self.lora_paths, list): if isinstance(self.lora_paths, list):
lora_paths = self.lora_paths lora_paths = self.lora_paths
......
import random
import string
import time
import unittest
from typing import Dict, List, Tuple
from tree import MultiTenantRadixTree
class TestMultiTenantRadixTree(unittest.TestCase):
def setUp(self):
self.tree = MultiTenantRadixTree()
def test_insert_exact_match(self):
"""Test 1: Basic insert and exact match operations"""
# Insert a single string for one tenant
self.tree.insert("hello", "tenant1")
matched, tenant = self.tree.prefix_match("hello")
self.assertEqual(matched, "hello")
self.assertEqual(tenant, "tenant1")
# Insert same string for different tenant
self.tree.insert("hello", "tenant2")
matched, tenant = self.tree.prefix_match("hello")
self.assertIn(tenant, ["tenant1", "tenant2"])
# Insert different string for same tenant
self.tree.insert("world", "tenant1")
matched, tenant = self.tree.prefix_match("world")
self.assertEqual(matched, "world")
self.assertEqual(tenant, "tenant1")
print(self.tree.pretty_print())
def test_insert_partial_match(self):
"""Test 2: Insert with partial matching scenarios"""
# Test partial matches with common prefixes
self.tree.insert("hello", "tenant1")
print(self.tree.pretty_print())
self.tree.insert("help", "tenant2")
print(self.tree.pretty_print())
# Match exact strings
matched, tenant = self.tree.prefix_match("hello")
self.assertEqual(matched, "hello")
self.assertEqual(tenant, "tenant1")
matched, tenant = self.tree.prefix_match("help")
self.assertEqual(matched, "help")
self.assertEqual(tenant, "tenant2")
# Match partial string
matched, tenant = self.tree.prefix_match("hel")
self.assertEqual(matched, "hel")
self.assertIn(tenant, ["tenant1", "tenant2"])
# Match longer string
matched, tenant = self.tree.prefix_match("hello_world")
self.assertEqual(matched, "hello")
self.assertEqual(tenant, "tenant1")
def test_insert_edge_cases(self):
"""Test 3: Edge cases for insert and match operations"""
# Empty string
self.tree.insert("", "tenant1")
matched, tenant = self.tree.prefix_match("")
self.assertEqual(matched, "")
self.assertEqual(tenant, "tenant1")
# Single character
self.tree.insert("a", "tenant1")
matched, tenant = self.tree.prefix_match("a")
self.assertEqual(matched, "a")
self.assertEqual(tenant, "tenant1")
# Very long string
long_str = "a" * 1000
self.tree.insert(long_str, "tenant1")
matched, tenant = self.tree.prefix_match(long_str)
self.assertEqual(matched, long_str)
self.assertEqual(tenant, "tenant1")
# Unicode characters
self.tree.insert("你好", "tenant1")
matched, tenant = self.tree.prefix_match("你好")
self.assertEqual(matched, "你好")
self.assertEqual(tenant, "tenant1")
def test_simple_eviction(self):
"""Test 4: Simple eviction scenarios
Tenant1: limit 10 chars
Tenant2: limit 5 chars
Should demonstrate:
1. Basic eviction when size limit exceeded
2. Proper eviction based on last access time
3. Verification that shared nodes remain intact for other tenants
"""
# Set up size limits
max_size = {"tenant1": 10, "tenant2": 5}
# Insert strings for both tenants
self.tree.insert("hello", "tenant1") # size 5
self.tree.insert("hello", "tenant2") # size 5
self.tree.insert("world", "tenant2") # size 5, total for tenant2 = 10
# Verify initial sizes
sizes_before = self.tree.get_used_size_per_tenant()
self.assertEqual(sizes_before["tenant1"], 5) # "hello" = 5
self.assertEqual(sizes_before["tenant2"], 10) # "hello" + "world" = 10
# Evict - should remove "hello" from tenant2 as it's the oldest
self.tree.evict_tenant_data(max_size)
# Verify sizes after eviction
sizes_after = self.tree.get_used_size_per_tenant()
self.assertEqual(sizes_after["tenant1"], 5) # Should be unchanged
self.assertEqual(sizes_after["tenant2"], 5) # Only "world" remains
# Verify "world" remains for tenant2 (was accessed more recently)
matched, tenant = self.tree.prefix_match("world")
self.assertEqual(matched, "world")
self.assertEqual(tenant, "tenant2")
def test_medium_eviction(self):
"""Test 5: Medium complexity eviction scenarios with shared prefixes
Tenant1: limit 10 chars
Tenant2: limit 7 chars (forces one string to be evicted)
Tree structure after inserts:
└── 'h' [t1, t2]
├── 'i' [t1, t2] # Oldest for t2
└── 'e' [t1, t2]
├── 'llo' [t1, t2]
└── 'y' [t2] # Newest for t2
Size calculations:
tenant1: "h"(1) + "i"(1) + "e"(1) + "llo"(3) = 6 chars
tenant2: "h"(1) + "i"(1) + "e"(1) + "llo"(3) + "y"(1) = 7 chars
After eviction (tenant2 exceeds limit by 1 char):
"hi" should be removed from tenant2 as it's the oldest access
"""
max_size = {
"tenant1": 10,
"tenant2": 6,
} # tenant2 will need to evict one string
# Create a tree with overlapping prefixes
self.tree.insert("hi", "tenant1")
self.tree.insert("hi", "tenant2") # OLDEST for t2
self.tree.insert("hello", "tenant1")
self.tree.insert("hello", "tenant2")
self.tree.insert("hey", "tenant2") # NEWEST for t2
# Verify initial sizes
sizes_before = self.tree.get_used_size_per_tenant()
self.assertEqual(sizes_before["tenant1"], 6) # h(1) + i(1) + e(1) + llo(3) = 6
self.assertEqual(
sizes_before["tenant2"], 7
) # h(1) + i(1) + e(1) + llo(3) + y(1) = 7
print("\nTree before eviction:")
print(self.tree.pretty_print())
# Evict - should remove "hi" from tenant2 as it's the oldest
self.tree.evict_tenant_data(max_size)
print("\nTree after eviction:")
print(self.tree.pretty_print())
# Verify sizes after eviction
sizes_after = self.tree.get_used_size_per_tenant()
self.assertEqual(sizes_after["tenant1"], 6) # Should be unchanged
self.assertEqual(sizes_after["tenant2"], 6) # h(1) + e(1) + llo(3) + y(1) = 6
def test_advanced_eviction(self):
...
# Create 4 tenants
# Each tenants keeps adding strings with shared prefixes to thousands usage
# Set a strict limit for each tenant to only 100
# At the end, check whether all of the tenant is under 100 after eviction
max_size = {"tenant1": 100, "tenant2": 100, "tenant3": 100, "tenant4": 100}
prefixes = ["aqwefcisdf", "iajsdfkmade", "kjnzxcvewqe", "iejksduqasd"]
for i in range(100):
for j, prefix in enumerate(prefixes):
random_suffix = "".join(random.choices(string.ascii_letters, k=10))
self.tree.insert(prefix + random_suffix, f"tenant{j+1}")
sizes_before = self.tree.get_used_size_per_tenant()
print(sizes_before)
self.tree.evict_tenant_data(max_size)
sizes_after = self.tree.get_used_size_per_tenant()
print(sizes_after)
# ensure size_after is below max_size
for tenant, size in sizes_after.items():
self.assertLessEqual(size, max_size[tenant])
if __name__ == "__main__":
unittest.main()
import time
from collections import defaultdict
from typing import Dict, List
class Node:
def __init__(self):
self.children: Dict[str, Node] = dict()
# We choose to use text because most of the use cases are text-to-text,
# so we can save the tokenizing overhead.
self.text: str = ""
# Maps tenant_id to their last access timestamp
self.tenant_last_access_time: Dict[str, float] = dict()
self.parent = None
def shared_prefix_length(s1, s2):
min_length = min(len(s1), len(s2))
for i in range(min_length):
if s1[i] != s2[i]:
return i
return min_length
class MultiTenantRadixTree:
"""
Python Reference of Rust implementation of MultiTenantRadixTree
MultiTenantRadixTree is the overlap of multiple radix trees by different tenant
Each node in the tree can be owned by multiple tenants, allowing for efficient storage of common prefixes
while maintaining tenant isolation.
Key concepts:
- Tenant: An entity that owns a subset of the stored strings
- Each node tracks which tenants have access to it via tenant_last_access_time
- The tree structure is shared, but queries can be filtered by tenant_id
"""
def __init__(self):
self.root = Node()
def insert(self, s: str, tenant_id: str) -> None:
"""
Insert string 's' and associate it with the given tenant_id.
Args:
s: The string to insert
tenant_id: The identifier of the tenant who owns this string
"""
curr = self.root
curr_idx = 0
curr.tenant_last_access_time[tenant_id] = time.time()
while curr_idx < len(s):
matched_node = None
if s[curr_idx] in curr.children:
matched_node = curr.children[s[curr_idx]]
if matched_node is None:
# No match => create a new node
new_node = Node()
new_node.text = s[curr_idx:]
new_node.parent = curr
curr.children[s[curr_idx]] = new_node
curr_idx = len(s)
curr = new_node
curr.tenant_last_access_time[tenant_id] = time.time()
else:
shared_len = shared_prefix_length(s[curr_idx:], matched_node.text)
# 1. If the matched text is shorter than the node text => split the node
if shared_len < len(matched_node.text):
# Split structure: [matched_node] => [new_node] -> [contracted_matched_node]
matched_text = matched_node.text[:shared_len]
unmatched_text = matched_node.text[shared_len:]
new_node = Node()
new_node.text = matched_text
new_node.children = {unmatched_text[0]: matched_node}
new_node.parent = curr
new_node.parent.children[matched_text[0]] = new_node
new_node.tenant_last_access_time = (
matched_node.tenant_last_access_time.copy()
)
# Contract matched node
matched_node.text = unmatched_text
matched_node.parent = new_node
curr_idx += shared_len
curr = new_node
curr.tenant_last_access_time[tenant_id] = time.time()
# 2. If the matched text is longer or equal to the node text => walk down the node
else:
curr_idx += shared_len
curr = matched_node
curr.tenant_last_access_time[tenant_id] = time.time()
def prefix_match(self, s: str) -> tuple[str, int]:
"""
Match string 's' with multiple tenants' trees in one operation.
Args:
s: The string to match
Returns:
Tuple(str, int): The longest prefix of 's' that matches the tree and the first tenant_id that own the matched prefix
"""
curr = self.root
curr_idx = 0
ret_text = ""
ret_tenant = None
while curr_idx < len(s):
matched_node = None
if s[curr_idx] in curr.children:
matched_node = curr.children[s[curr_idx]]
if matched_node is None:
break
shared_len = shared_prefix_length(s[curr_idx:], matched_node.text)
if shared_len == len(matched_node.text):
curr_idx += shared_len
curr = matched_node
else:
curr_idx += shared_len
curr = matched_node
break
selected_tenant = list(curr.tenant_last_access_time.keys())[0]
# traverse back to the root to update last access time for the selected tenant
while curr != self.root:
curr.tenant_last_access_time[selected_tenant] = time.time()
curr = curr.parent
return s[:curr_idx], selected_tenant
def evict_tenant_data(self, max_size_per_tenant: Dict[str, int]) -> None:
"""
Evict data for tenants that have exceeded their storage limits.
Args:
max_size_per_tenant: Dictionary mapping tenant_id to their maximum allowed storage size
"""
def leaf_of(node):
"""
If the node is a leaf for a tenant, add tenant_id to the return list
This will return list of tenant ids
If not a leaf for all tenants, return []
"""
candidates = dict([(k, True) for k in node.tenant_last_access_time.keys()])
for n in node.children.values():
for c in n.tenant_last_access_time.keys():
candidates[c] = False
return [k for k, v in candidates.items() if v]
# maintain a heap with (time, tenant, node) as the value
import heapq
# 1. traverse the tree to
# a. add all the leaves into a heap (a node with N tenants will be added N times into the heap)
# b. calculate the used size for each tenant
# do a dfs with stack
stack = [self.root]
pq = []
used_size_per_tenant = defaultdict(int)
while stack:
curr = stack.pop()
for t in curr.tenant_last_access_time.keys():
used_size_per_tenant[t] += len(curr.text)
for c in curr.children.values():
stack.append(c)
# if the node is a leaf for a tenant, add the tenant to the heap
tenants = leaf_of(curr)
for t in tenants:
heapq.heappush(pq, (curr.tenant_last_access_time[t], t, curr))
# 2. pop the heap
# a. if the tenant's used size is less than the limit, continue
# b. if the tenant's used size is greater than the limit, remove the leaf and update the used size, and add its parent to the heap
while len(pq) > 0:
time, tenant, node = heapq.heappop(pq)
if used_size_per_tenant[tenant] <= max_size_per_tenant[tenant]:
continue
# remove the leaf
used_size_per_tenant[tenant] -= len(node.text)
del node.tenant_last_access_time[tenant]
# if no children and no tenants, remove the node
if len(node.children) == 0 and len(node.tenant_last_access_time) == 0:
del node.parent.children[node.text[0]]
# add its parent to the heap
if tenant in leaf_of(node.parent):
heapq.heappush(
pq,
(node.parent.tenant_last_access_time[tenant], tenant, node.parent),
)
def get_used_size_per_tenant(self) -> Dict[str, int]:
"""
Calculate the used storage size for each tenant.
Returns:
Dict[str, int]: A dictionary mapping tenant_id to their used storage size
"""
used_size_per_tenant = defaultdict(int)
stack = [self.root]
while stack:
curr = stack.pop()
for t in curr.tenant_last_access_time.keys():
used_size_per_tenant[t] += len(curr.text)
for c in curr.children.values():
stack.append(c)
return used_size_per_tenant
def remove_tenant(self, tenant_id: str) -> None:
"""
Remove all data associated with a specific tenant from the tree.
This operation maintains the integrity of the shared tree structure while
removing only the specified tenant's access information.
Args:
tenant_id: The identifier of the tenant whose data should be removed
"""
# TODO: Implementation needed
pass
def pretty_print(self) -> str:
"""
Returns a string representation of the tree showing the structure, tenant ownership,
and leaf status for each node.
Returns:
str: A formatted string showing the tree hierarchy with tenant information
"""
def _node_to_str(node: Node, prefix: str = "", is_last: bool = True) -> str:
# Current node representation
node_str = prefix
node_str += "└── " if is_last else "├── "
# Add node text
node_str += f"'{node.text}' ["
# Add tenant information including both timestamp and leaf status
tenant_info = []
for tid, ts in node.tenant_last_access_time.items():
time_str = (
time.strftime("%H:%M:%S.", time.localtime(ts))
+ f"{(ts % 1):0.3f}"[2:]
)
tenant_info.append(f"{tid} | {time_str}")
node_str += ", ".join(tenant_info)
node_str += "]\n"
# Handle children
children = list(node.children.items())
for i, (char, child) in enumerate(children):
is_last_child = i == len(children) - 1
# Adjust prefix for children based on whether this is the last child
new_prefix = prefix + (" " if is_last else "│ ")
node_str += _node_to_str(child, new_prefix, is_last_child)
return node_str
if not self.root.children:
return "Empty tree"
# Start with root's children since root itself is just an empty node
result = ""
children = list(self.root.children.items())
for i, (char, child) in enumerate(children):
is_last = i == len(children) - 1
result += _node_to_str(child, "", is_last)
return result
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment