Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
30af7dfb
Unverified
Commit
30af7dfb
authored
Nov 21, 2024
by
Byron Hsu
Committed by
GitHub
Nov 21, 2024
Browse files
[router] add base_gpu_id server args & merged radix tree python reference (#2115)
parent
f6f71379
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
513 additions
and
2 deletions
+513
-2
python/sglang/srt/managers/data_parallel_controller.py
python/sglang/srt/managers/data_parallel_controller.py
+1
-1
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+4
-0
python/sglang/srt/server.py
python/sglang/srt/server.py
+1
-1
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+8
-0
scripts/playground/router/test_tree.py
scripts/playground/router/test_tree.py
+207
-0
scripts/playground/router/tree.py
scripts/playground/router/tree.py
+292
-0
No files found.
python/sglang/srt/managers/data_parallel_controller.py
View file @
30af7dfb
...
@@ -156,7 +156,7 @@ class DataParallelController:
...
@@ -156,7 +156,7 @@ class DataParallelController:
)
)
for
tp_rank
in
tp_rank_range
:
for
tp_rank
in
tp_rank_range
:
reader
,
writer
=
mp
.
Pipe
(
duplex
=
False
)
reader
,
writer
=
mp
.
Pipe
(
duplex
=
False
)
gpu_id
=
base_gpu_id
+
tp_rank
%
tp_size_per_node
gpu_id
=
server_args
.
base_gpu_id
+
base_gpu_id
+
tp_rank
%
tp_size_per_node
proc
=
mp
.
Process
(
proc
=
mp
.
Process
(
target
=
run_scheduler_process
,
target
=
run_scheduler_process
,
args
=
(
server_args
,
port_args
,
gpu_id
,
tp_rank
,
dp_rank
,
writer
),
args
=
(
server_args
,
port_args
,
gpu_id
,
tp_rank
,
dp_rank
,
writer
),
...
...
python/sglang/srt/managers/scheduler.py
View file @
30af7dfb
...
@@ -1380,6 +1380,10 @@ def run_scheduler_process(
...
@@ -1380,6 +1380,10 @@ def run_scheduler_process(
dp_rank
:
Optional
[
int
],
dp_rank
:
Optional
[
int
],
pipe_writer
,
pipe_writer
,
):
):
# [For Router] if env var "DP_RANK" exist, set dp_rank to the value of the env var
if
dp_rank
is
None
:
dp_rank
=
int
(
os
.
getenv
(
"DP_RANK"
,
-
1
))
if
dp_rank
is
None
:
if
dp_rank
is
None
:
configure_logger
(
server_args
,
prefix
=
f
" TP
{
tp_rank
}
"
)
configure_logger
(
server_args
,
prefix
=
f
" TP
{
tp_rank
}
"
)
else
:
else
:
...
...
python/sglang/srt/server.py
View file @
30af7dfb
...
@@ -418,7 +418,7 @@ def launch_engine(
...
@@ -418,7 +418,7 @@ def launch_engine(
)
)
for
tp_rank
in
tp_rank_range
:
for
tp_rank
in
tp_rank_range
:
reader
,
writer
=
mp
.
Pipe
(
duplex
=
False
)
reader
,
writer
=
mp
.
Pipe
(
duplex
=
False
)
gpu_id
=
tp_rank
%
tp_size_per_node
gpu_id
=
server_args
.
base_gpu_id
+
tp_rank
%
tp_size_per_node
proc
=
mp
.
Process
(
proc
=
mp
.
Process
(
target
=
run_scheduler_process
,
target
=
run_scheduler_process
,
args
=
(
server_args
,
port_args
,
gpu_id
,
tp_rank
,
None
,
writer
),
args
=
(
server_args
,
port_args
,
gpu_id
,
tp_rank
,
None
,
writer
),
...
...
python/sglang/srt/server_args.py
View file @
30af7dfb
...
@@ -72,6 +72,7 @@ class ServerArgs:
...
@@ -72,6 +72,7 @@ class ServerArgs:
constrained_json_whitespace_pattern
:
Optional
[
str
]
=
None
constrained_json_whitespace_pattern
:
Optional
[
str
]
=
None
watchdog_timeout
:
float
=
300
watchdog_timeout
:
float
=
300
download_dir
:
Optional
[
str
]
=
None
download_dir
:
Optional
[
str
]
=
None
base_gpu_id
:
int
=
0
# Logging
# Logging
log_level
:
str
=
"info"
log_level
:
str
=
"info"
...
@@ -412,6 +413,12 @@ class ServerArgs:
...
@@ -412,6 +413,12 @@ class ServerArgs:
default
=
ServerArgs
.
download_dir
,
default
=
ServerArgs
.
download_dir
,
help
=
"Model download directory."
,
help
=
"Model download directory."
,
)
)
parser
.
add_argument
(
"--base-gpu-id"
,
type
=
int
,
default
=
ServerArgs
.
base_gpu_id
,
help
=
"The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine."
,
)
# Logging
# Logging
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -736,6 +743,7 @@ class ServerArgs:
...
@@ -736,6 +743,7 @@ class ServerArgs:
and
(
self
.
lora_paths
is
None
or
self
.
disable_cuda_graph
)
and
(
self
.
lora_paths
is
None
or
self
.
disable_cuda_graph
)
and
(
self
.
lora_paths
is
None
or
self
.
disable_radix_cache
)
and
(
self
.
lora_paths
is
None
or
self
.
disable_radix_cache
)
),
"compatibility of lora and cuda graph and radix attention is in progress"
),
"compatibility of lora and cuda graph and radix attention is in progress"
assert
self
.
base_gpu_id
>=
0
,
"base_gpu_id must be non-negative"
if
isinstance
(
self
.
lora_paths
,
list
):
if
isinstance
(
self
.
lora_paths
,
list
):
lora_paths
=
self
.
lora_paths
lora_paths
=
self
.
lora_paths
...
...
scripts/playground/router/test_tree.py
0 → 100644
View file @
30af7dfb
import
random
import
string
import
time
import
unittest
from
typing
import
Dict
,
List
,
Tuple
from
tree
import
MultiTenantRadixTree
class
TestMultiTenantRadixTree
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
tree
=
MultiTenantRadixTree
()
def
test_insert_exact_match
(
self
):
"""Test 1: Basic insert and exact match operations"""
# Insert a single string for one tenant
self
.
tree
.
insert
(
"hello"
,
"tenant1"
)
matched
,
tenant
=
self
.
tree
.
prefix_match
(
"hello"
)
self
.
assertEqual
(
matched
,
"hello"
)
self
.
assertEqual
(
tenant
,
"tenant1"
)
# Insert same string for different tenant
self
.
tree
.
insert
(
"hello"
,
"tenant2"
)
matched
,
tenant
=
self
.
tree
.
prefix_match
(
"hello"
)
self
.
assertIn
(
tenant
,
[
"tenant1"
,
"tenant2"
])
# Insert different string for same tenant
self
.
tree
.
insert
(
"world"
,
"tenant1"
)
matched
,
tenant
=
self
.
tree
.
prefix_match
(
"world"
)
self
.
assertEqual
(
matched
,
"world"
)
self
.
assertEqual
(
tenant
,
"tenant1"
)
print
(
self
.
tree
.
pretty_print
())
def
test_insert_partial_match
(
self
):
"""Test 2: Insert with partial matching scenarios"""
# Test partial matches with common prefixes
self
.
tree
.
insert
(
"hello"
,
"tenant1"
)
print
(
self
.
tree
.
pretty_print
())
self
.
tree
.
insert
(
"help"
,
"tenant2"
)
print
(
self
.
tree
.
pretty_print
())
# Match exact strings
matched
,
tenant
=
self
.
tree
.
prefix_match
(
"hello"
)
self
.
assertEqual
(
matched
,
"hello"
)
self
.
assertEqual
(
tenant
,
"tenant1"
)
matched
,
tenant
=
self
.
tree
.
prefix_match
(
"help"
)
self
.
assertEqual
(
matched
,
"help"
)
self
.
assertEqual
(
tenant
,
"tenant2"
)
# Match partial string
matched
,
tenant
=
self
.
tree
.
prefix_match
(
"hel"
)
self
.
assertEqual
(
matched
,
"hel"
)
self
.
assertIn
(
tenant
,
[
"tenant1"
,
"tenant2"
])
# Match longer string
matched
,
tenant
=
self
.
tree
.
prefix_match
(
"hello_world"
)
self
.
assertEqual
(
matched
,
"hello"
)
self
.
assertEqual
(
tenant
,
"tenant1"
)
def
test_insert_edge_cases
(
self
):
"""Test 3: Edge cases for insert and match operations"""
# Empty string
self
.
tree
.
insert
(
""
,
"tenant1"
)
matched
,
tenant
=
self
.
tree
.
prefix_match
(
""
)
self
.
assertEqual
(
matched
,
""
)
self
.
assertEqual
(
tenant
,
"tenant1"
)
# Single character
self
.
tree
.
insert
(
"a"
,
"tenant1"
)
matched
,
tenant
=
self
.
tree
.
prefix_match
(
"a"
)
self
.
assertEqual
(
matched
,
"a"
)
self
.
assertEqual
(
tenant
,
"tenant1"
)
# Very long string
long_str
=
"a"
*
1000
self
.
tree
.
insert
(
long_str
,
"tenant1"
)
matched
,
tenant
=
self
.
tree
.
prefix_match
(
long_str
)
self
.
assertEqual
(
matched
,
long_str
)
self
.
assertEqual
(
tenant
,
"tenant1"
)
# Unicode characters
self
.
tree
.
insert
(
"你好"
,
"tenant1"
)
matched
,
tenant
=
self
.
tree
.
prefix_match
(
"你好"
)
self
.
assertEqual
(
matched
,
"你好"
)
self
.
assertEqual
(
tenant
,
"tenant1"
)
def
test_simple_eviction
(
self
):
"""Test 4: Simple eviction scenarios
Tenant1: limit 10 chars
Tenant2: limit 5 chars
Should demonstrate:
1. Basic eviction when size limit exceeded
2. Proper eviction based on last access time
3. Verification that shared nodes remain intact for other tenants
"""
# Set up size limits
max_size
=
{
"tenant1"
:
10
,
"tenant2"
:
5
}
# Insert strings for both tenants
self
.
tree
.
insert
(
"hello"
,
"tenant1"
)
# size 5
self
.
tree
.
insert
(
"hello"
,
"tenant2"
)
# size 5
self
.
tree
.
insert
(
"world"
,
"tenant2"
)
# size 5, total for tenant2 = 10
# Verify initial sizes
sizes_before
=
self
.
tree
.
get_used_size_per_tenant
()
self
.
assertEqual
(
sizes_before
[
"tenant1"
],
5
)
# "hello" = 5
self
.
assertEqual
(
sizes_before
[
"tenant2"
],
10
)
# "hello" + "world" = 10
# Evict - should remove "hello" from tenant2 as it's the oldest
self
.
tree
.
evict_tenant_data
(
max_size
)
# Verify sizes after eviction
sizes_after
=
self
.
tree
.
get_used_size_per_tenant
()
self
.
assertEqual
(
sizes_after
[
"tenant1"
],
5
)
# Should be unchanged
self
.
assertEqual
(
sizes_after
[
"tenant2"
],
5
)
# Only "world" remains
# Verify "world" remains for tenant2 (was accessed more recently)
matched
,
tenant
=
self
.
tree
.
prefix_match
(
"world"
)
self
.
assertEqual
(
matched
,
"world"
)
self
.
assertEqual
(
tenant
,
"tenant2"
)
def
test_medium_eviction
(
self
):
"""Test 5: Medium complexity eviction scenarios with shared prefixes
Tenant1: limit 10 chars
Tenant2: limit 7 chars (forces one string to be evicted)
Tree structure after inserts:
└── 'h' [t1, t2]
├── 'i' [t1, t2] # Oldest for t2
└── 'e' [t1, t2]
├── 'llo' [t1, t2]
└── 'y' [t2] # Newest for t2
Size calculations:
tenant1: "h"(1) + "i"(1) + "e"(1) + "llo"(3) = 6 chars
tenant2: "h"(1) + "i"(1) + "e"(1) + "llo"(3) + "y"(1) = 7 chars
After eviction (tenant2 exceeds limit by 1 char):
"hi" should be removed from tenant2 as it's the oldest access
"""
max_size
=
{
"tenant1"
:
10
,
"tenant2"
:
6
,
}
# tenant2 will need to evict one string
# Create a tree with overlapping prefixes
self
.
tree
.
insert
(
"hi"
,
"tenant1"
)
self
.
tree
.
insert
(
"hi"
,
"tenant2"
)
# OLDEST for t2
self
.
tree
.
insert
(
"hello"
,
"tenant1"
)
self
.
tree
.
insert
(
"hello"
,
"tenant2"
)
self
.
tree
.
insert
(
"hey"
,
"tenant2"
)
# NEWEST for t2
# Verify initial sizes
sizes_before
=
self
.
tree
.
get_used_size_per_tenant
()
self
.
assertEqual
(
sizes_before
[
"tenant1"
],
6
)
# h(1) + i(1) + e(1) + llo(3) = 6
self
.
assertEqual
(
sizes_before
[
"tenant2"
],
7
)
# h(1) + i(1) + e(1) + llo(3) + y(1) = 7
print
(
"
\n
Tree before eviction:"
)
print
(
self
.
tree
.
pretty_print
())
# Evict - should remove "hi" from tenant2 as it's the oldest
self
.
tree
.
evict_tenant_data
(
max_size
)
print
(
"
\n
Tree after eviction:"
)
print
(
self
.
tree
.
pretty_print
())
# Verify sizes after eviction
sizes_after
=
self
.
tree
.
get_used_size_per_tenant
()
self
.
assertEqual
(
sizes_after
[
"tenant1"
],
6
)
# Should be unchanged
self
.
assertEqual
(
sizes_after
[
"tenant2"
],
6
)
# h(1) + e(1) + llo(3) + y(1) = 6
def
test_advanced_eviction
(
self
):
...
# Create 4 tenants
# Each tenants keeps adding strings with shared prefixes to thousands usage
# Set a strict limit for each tenant to only 100
# At the end, check whether all of the tenant is under 100 after eviction
max_size
=
{
"tenant1"
:
100
,
"tenant2"
:
100
,
"tenant3"
:
100
,
"tenant4"
:
100
}
prefixes
=
[
"aqwefcisdf"
,
"iajsdfkmade"
,
"kjnzxcvewqe"
,
"iejksduqasd"
]
for
i
in
range
(
100
):
for
j
,
prefix
in
enumerate
(
prefixes
):
random_suffix
=
""
.
join
(
random
.
choices
(
string
.
ascii_letters
,
k
=
10
))
self
.
tree
.
insert
(
prefix
+
random_suffix
,
f
"tenant
{
j
+
1
}
"
)
sizes_before
=
self
.
tree
.
get_used_size_per_tenant
()
print
(
sizes_before
)
self
.
tree
.
evict_tenant_data
(
max_size
)
sizes_after
=
self
.
tree
.
get_used_size_per_tenant
()
print
(
sizes_after
)
# ensure size_after is below max_size
for
tenant
,
size
in
sizes_after
.
items
():
self
.
assertLessEqual
(
size
,
max_size
[
tenant
])
if
__name__
==
"__main__"
:
unittest
.
main
()
scripts/playground/router/tree.py
0 → 100644
View file @
30af7dfb
import
time
from
collections
import
defaultdict
from
typing
import
Dict
,
List
class
Node
:
def
__init__
(
self
):
self
.
children
:
Dict
[
str
,
Node
]
=
dict
()
# We choose to use text because most of the use cases are text-to-text,
# so we can save the tokenizing overhead.
self
.
text
:
str
=
""
# Maps tenant_id to their last access timestamp
self
.
tenant_last_access_time
:
Dict
[
str
,
float
]
=
dict
()
self
.
parent
=
None
def
shared_prefix_length
(
s1
,
s2
):
min_length
=
min
(
len
(
s1
),
len
(
s2
))
for
i
in
range
(
min_length
):
if
s1
[
i
]
!=
s2
[
i
]:
return
i
return
min_length
class
MultiTenantRadixTree
:
"""
Python Reference of Rust implementation of MultiTenantRadixTree
MultiTenantRadixTree is the overlap of multiple radix trees by different tenant
Each node in the tree can be owned by multiple tenants, allowing for efficient storage of common prefixes
while maintaining tenant isolation.
Key concepts:
- Tenant: An entity that owns a subset of the stored strings
- Each node tracks which tenants have access to it via tenant_last_access_time
- The tree structure is shared, but queries can be filtered by tenant_id
"""
def
__init__
(
self
):
self
.
root
=
Node
()
def
insert
(
self
,
s
:
str
,
tenant_id
:
str
)
->
None
:
"""
Insert string 's' and associate it with the given tenant_id.
Args:
s: The string to insert
tenant_id: The identifier of the tenant who owns this string
"""
curr
=
self
.
root
curr_idx
=
0
curr
.
tenant_last_access_time
[
tenant_id
]
=
time
.
time
()
while
curr_idx
<
len
(
s
):
matched_node
=
None
if
s
[
curr_idx
]
in
curr
.
children
:
matched_node
=
curr
.
children
[
s
[
curr_idx
]]
if
matched_node
is
None
:
# No match => create a new node
new_node
=
Node
()
new_node
.
text
=
s
[
curr_idx
:]
new_node
.
parent
=
curr
curr
.
children
[
s
[
curr_idx
]]
=
new_node
curr_idx
=
len
(
s
)
curr
=
new_node
curr
.
tenant_last_access_time
[
tenant_id
]
=
time
.
time
()
else
:
shared_len
=
shared_prefix_length
(
s
[
curr_idx
:],
matched_node
.
text
)
# 1. If the matched text is shorter than the node text => split the node
if
shared_len
<
len
(
matched_node
.
text
):
# Split structure: [matched_node] => [new_node] -> [contracted_matched_node]
matched_text
=
matched_node
.
text
[:
shared_len
]
unmatched_text
=
matched_node
.
text
[
shared_len
:]
new_node
=
Node
()
new_node
.
text
=
matched_text
new_node
.
children
=
{
unmatched_text
[
0
]:
matched_node
}
new_node
.
parent
=
curr
new_node
.
parent
.
children
[
matched_text
[
0
]]
=
new_node
new_node
.
tenant_last_access_time
=
(
matched_node
.
tenant_last_access_time
.
copy
()
)
# Contract matched node
matched_node
.
text
=
unmatched_text
matched_node
.
parent
=
new_node
curr_idx
+=
shared_len
curr
=
new_node
curr
.
tenant_last_access_time
[
tenant_id
]
=
time
.
time
()
# 2. If the matched text is longer or equal to the node text => walk down the node
else
:
curr_idx
+=
shared_len
curr
=
matched_node
curr
.
tenant_last_access_time
[
tenant_id
]
=
time
.
time
()
def
prefix_match
(
self
,
s
:
str
)
->
tuple
[
str
,
int
]:
"""
Match string 's' with multiple tenants' trees in one operation.
Args:
s: The string to match
Returns:
Tuple(str, int): The longest prefix of 's' that matches the tree and the first tenant_id that own the matched prefix
"""
curr
=
self
.
root
curr_idx
=
0
ret_text
=
""
ret_tenant
=
None
while
curr_idx
<
len
(
s
):
matched_node
=
None
if
s
[
curr_idx
]
in
curr
.
children
:
matched_node
=
curr
.
children
[
s
[
curr_idx
]]
if
matched_node
is
None
:
break
shared_len
=
shared_prefix_length
(
s
[
curr_idx
:],
matched_node
.
text
)
if
shared_len
==
len
(
matched_node
.
text
):
curr_idx
+=
shared_len
curr
=
matched_node
else
:
curr_idx
+=
shared_len
curr
=
matched_node
break
selected_tenant
=
list
(
curr
.
tenant_last_access_time
.
keys
())[
0
]
# traverse back to the root to update last access time for the selected tenant
while
curr
!=
self
.
root
:
curr
.
tenant_last_access_time
[
selected_tenant
]
=
time
.
time
()
curr
=
curr
.
parent
return
s
[:
curr_idx
],
selected_tenant
def
evict_tenant_data
(
self
,
max_size_per_tenant
:
Dict
[
str
,
int
])
->
None
:
"""
Evict data for tenants that have exceeded their storage limits.
Args:
max_size_per_tenant: Dictionary mapping tenant_id to their maximum allowed storage size
"""
def
leaf_of
(
node
):
"""
If the node is a leaf for a tenant, add tenant_id to the return list
This will return list of tenant ids
If not a leaf for all tenants, return []
"""
candidates
=
dict
([(
k
,
True
)
for
k
in
node
.
tenant_last_access_time
.
keys
()])
for
n
in
node
.
children
.
values
():
for
c
in
n
.
tenant_last_access_time
.
keys
():
candidates
[
c
]
=
False
return
[
k
for
k
,
v
in
candidates
.
items
()
if
v
]
# maintain a heap with (time, tenant, node) as the value
import
heapq
# 1. traverse the tree to
# a. add all the leaves into a heap (a node with N tenants will be added N times into the heap)
# b. calculate the used size for each tenant
# do a dfs with stack
stack
=
[
self
.
root
]
pq
=
[]
used_size_per_tenant
=
defaultdict
(
int
)
while
stack
:
curr
=
stack
.
pop
()
for
t
in
curr
.
tenant_last_access_time
.
keys
():
used_size_per_tenant
[
t
]
+=
len
(
curr
.
text
)
for
c
in
curr
.
children
.
values
():
stack
.
append
(
c
)
# if the node is a leaf for a tenant, add the tenant to the heap
tenants
=
leaf_of
(
curr
)
for
t
in
tenants
:
heapq
.
heappush
(
pq
,
(
curr
.
tenant_last_access_time
[
t
],
t
,
curr
))
# 2. pop the heap
# a. if the tenant's used size is less than the limit, continue
# b. if the tenant's used size is greater than the limit, remove the leaf and update the used size, and add its parent to the heap
while
len
(
pq
)
>
0
:
time
,
tenant
,
node
=
heapq
.
heappop
(
pq
)
if
used_size_per_tenant
[
tenant
]
<=
max_size_per_tenant
[
tenant
]:
continue
# remove the leaf
used_size_per_tenant
[
tenant
]
-=
len
(
node
.
text
)
del
node
.
tenant_last_access_time
[
tenant
]
# if no children and no tenants, remove the node
if
len
(
node
.
children
)
==
0
and
len
(
node
.
tenant_last_access_time
)
==
0
:
del
node
.
parent
.
children
[
node
.
text
[
0
]]
# add its parent to the heap
if
tenant
in
leaf_of
(
node
.
parent
):
heapq
.
heappush
(
pq
,
(
node
.
parent
.
tenant_last_access_time
[
tenant
],
tenant
,
node
.
parent
),
)
def
get_used_size_per_tenant
(
self
)
->
Dict
[
str
,
int
]:
"""
Calculate the used storage size for each tenant.
Returns:
Dict[str, int]: A dictionary mapping tenant_id to their used storage size
"""
used_size_per_tenant
=
defaultdict
(
int
)
stack
=
[
self
.
root
]
while
stack
:
curr
=
stack
.
pop
()
for
t
in
curr
.
tenant_last_access_time
.
keys
():
used_size_per_tenant
[
t
]
+=
len
(
curr
.
text
)
for
c
in
curr
.
children
.
values
():
stack
.
append
(
c
)
return
used_size_per_tenant
def
remove_tenant
(
self
,
tenant_id
:
str
)
->
None
:
"""
Remove all data associated with a specific tenant from the tree.
This operation maintains the integrity of the shared tree structure while
removing only the specified tenant's access information.
Args:
tenant_id: The identifier of the tenant whose data should be removed
"""
# TODO: Implementation needed
pass
def
pretty_print
(
self
)
->
str
:
"""
Returns a string representation of the tree showing the structure, tenant ownership,
and leaf status for each node.
Returns:
str: A formatted string showing the tree hierarchy with tenant information
"""
def
_node_to_str
(
node
:
Node
,
prefix
:
str
=
""
,
is_last
:
bool
=
True
)
->
str
:
# Current node representation
node_str
=
prefix
node_str
+=
"└── "
if
is_last
else
"├── "
# Add node text
node_str
+=
f
"'
{
node
.
text
}
' ["
# Add tenant information including both timestamp and leaf status
tenant_info
=
[]
for
tid
,
ts
in
node
.
tenant_last_access_time
.
items
():
time_str
=
(
time
.
strftime
(
"%H:%M:%S."
,
time
.
localtime
(
ts
))
+
f
"
{
(
ts
%
1
):
0.3
f
}
"
[
2
:]
)
tenant_info
.
append
(
f
"
{
tid
}
|
{
time_str
}
"
)
node_str
+=
", "
.
join
(
tenant_info
)
node_str
+=
"]
\n
"
# Handle children
children
=
list
(
node
.
children
.
items
())
for
i
,
(
char
,
child
)
in
enumerate
(
children
):
is_last_child
=
i
==
len
(
children
)
-
1
# Adjust prefix for children based on whether this is the last child
new_prefix
=
prefix
+
(
" "
if
is_last
else
"│ "
)
node_str
+=
_node_to_str
(
child
,
new_prefix
,
is_last_child
)
return
node_str
if
not
self
.
root
.
children
:
return
"Empty tree"
# Start with root's children since root itself is just an empty node
result
=
""
children
=
list
(
self
.
root
.
children
.
items
())
for
i
,
(
char
,
child
)
in
enumerate
(
children
):
is_last
=
i
==
len
(
children
)
-
1
result
+=
_node_to_str
(
child
,
""
,
is_last
)
return
result
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment