Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
9dae4078
Unverified
Commit
9dae4078
authored
Aug 11, 2024
by
Lianmin Zheng
Committed by
GitHub
Aug 11, 2024
Browse files
Improve type annotation (#1029)
parent
fcc0f5ed
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
83 additions
and
41 deletions
+83
-41
python/sglang/srt/managers/policy_scheduler.py
python/sglang/srt/managers/policy_scheduler.py
+17
-9
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+28
-10
python/sglang/srt/mem_cache/base_prefix_cache.py
python/sglang/srt/mem_cache/base_prefix_cache.py
+4
-3
python/sglang/srt/mem_cache/chunk_cache.py
python/sglang/srt/mem_cache/chunk_cache.py
+10
-5
python/sglang/srt/mem_cache/memory_pool.py
python/sglang/srt/mem_cache/memory_pool.py
+2
-2
python/sglang/srt/mem_cache/radix_cache.py
python/sglang/srt/mem_cache/radix_cache.py
+22
-12
No files found.
python/sglang/srt/managers/policy_scheduler.py
View file @
9dae4078
...
@@ -18,13 +18,15 @@ limitations under the License.
...
@@ -18,13 +18,15 @@ limitations under the License.
import
random
import
random
from
collections
import
defaultdict
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
List
from
typing
import
Dict
,
List
from
sglang.srt.managers.schedule_batch
import
Req
,
ScheduleBatch
from
sglang.srt.managers.schedule_batch
import
Req
,
ScheduleBatch
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.radix_cache
import
TreeNode
class
PolicyScheduler
:
class
PolicyScheduler
:
def
__init__
(
self
,
policy
,
tree_cache
):
def
__init__
(
self
,
policy
:
str
,
tree_cache
:
BasePrefixCache
):
if
tree_cache
.
disable
and
policy
in
[
"lpm"
,
"dfs-weight"
]:
if
tree_cache
.
disable
and
policy
in
[
"lpm"
,
"dfs-weight"
]:
# LPM and DFS-weight is meaningless when the tree cache is disabled.
# LPM and DFS-weight is meaningless when the tree cache is disabled.
policy
=
"fcfs"
policy
=
"fcfs"
...
@@ -72,12 +74,18 @@ class PolicyScheduler:
...
@@ -72,12 +74,18 @@ class PolicyScheduler:
else
:
else
:
raise
ValueError
(
f
"Unknown schedule_policy:
{
self
.
policy
}
"
)
raise
ValueError
(
f
"Unknown schedule_policy:
{
self
.
policy
}
"
)
def
calc_weight
(
self
,
cur_node
,
node_to_weight
):
def
calc_weight
(
self
,
cur_node
:
TreeNode
,
node_to_weight
:
Dict
):
for
child
in
cur_node
.
children
.
values
():
for
child
in
cur_node
.
children
.
values
():
self
.
calc_weight
(
child
,
node_to_weight
)
self
.
calc_weight
(
child
,
node_to_weight
)
node_to_weight
[
cur_node
]
+=
node_to_weight
[
child
]
node_to_weight
[
cur_node
]
+=
node_to_weight
[
child
]
def
get_dfs_priority
(
self
,
cur_node
,
node_to_priority
,
last_node_to_reqs
,
q
):
def
get_dfs_priority
(
self
,
cur_node
:
TreeNode
,
node_to_priority
:
Dict
,
last_node_to_reqs
:
Dict
,
q
:
List
,
):
childs
=
[
child
for
child
in
cur_node
.
children
.
values
()]
childs
=
[
child
for
child
in
cur_node
.
children
.
values
()]
childs
.
sort
(
key
=
lambda
x
:
-
node_to_priority
[
x
])
childs
.
sort
(
key
=
lambda
x
:
-
node_to_priority
[
x
])
for
child
in
childs
:
for
child
in
childs
:
...
@@ -88,10 +96,10 @@ class PolicyScheduler:
...
@@ -88,10 +96,10 @@ class PolicyScheduler:
class
PrefillAdder
:
class
PrefillAdder
:
def
__init__
(
def
__init__
(
self
,
self
,
tree_cache
,
tree_cache
:
BasePrefixCache
,
rem_total_tokens
,
rem_total_tokens
:
int
,
rem_input_tokens
,
rem_input_tokens
:
int
,
rem_chunk_tokens
,
rem_chunk_tokens
:
int
,
):
):
self
.
tree_cache
=
tree_cache
self
.
tree_cache
=
tree_cache
self
.
rem_total_tokens
=
rem_total_tokens
self
.
rem_total_tokens
=
rem_total_tokens
...
@@ -151,7 +159,7 @@ class PrefillAdder:
...
@@ -151,7 +159,7 @@ class PrefillAdder:
return
req
if
truncated
else
None
return
req
if
truncated
else
None
@
contextmanager
@
contextmanager
def
_lock_node
(
self
,
last_node
):
def
_lock_node
(
self
,
last_node
:
TreeNode
):
try
:
try
:
delta
=
self
.
tree_cache
.
inc_lock_ref
(
last_node
)
delta
=
self
.
tree_cache
.
inc_lock_ref
(
last_node
)
self
.
rem_total_tokens
+=
delta
self
.
rem_total_tokens
+=
delta
...
...
python/sglang/srt/managers/tp_worker.py
View file @
9dae4078
...
@@ -21,15 +21,17 @@ import os
...
@@ -21,15 +21,17 @@ import os
import
pickle
import
pickle
import
time
import
time
import
warnings
import
warnings
from
typing
import
List
,
Optional
,
Union
from
typing
import
Any
,
List
,
Optional
,
Union
import
torch
import
torch
import
torch.distributed
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
sglang.global_config
import
global_config
from
sglang.global_config
import
global_config
from
sglang.srt.constrained.fsm_cache
import
FSMCache
from
sglang.srt.constrained.fsm_cache
import
FSMCache
from
sglang.srt.constrained.jump_forward
import
JumpForwardCache
from
sglang.srt.constrained.jump_forward
import
JumpForwardCache
from
sglang.srt.hf_transformers_utils
import
get_processor
,
get_tokenizer
from
sglang.srt.hf_transformers_utils
import
get_processor
,
get_tokenizer
from
sglang.srt.layers.logits_processor
import
LogitProcessorOutput
from
sglang.srt.managers.io_struct
import
(
from
sglang.srt.managers.io_struct
import
(
AbortReq
,
AbortReq
,
BatchEmbeddingOut
,
BatchEmbeddingOut
,
...
@@ -62,6 +64,10 @@ from sglang.utils import get_exception_traceback
...
@@ -62,6 +64,10 @@ from sglang.utils import get_exception_traceback
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
# TODO: Rename "CI" to "SGLANG_IS_IN_CI".
crash_on_warning
=
os
.
getenv
(
"CI"
,
"false"
)
==
"true"
class
ModelTpServer
:
class
ModelTpServer
:
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -198,7 +204,7 @@ class ModelTpServer:
...
@@ -198,7 +204,7 @@ class ModelTpServer:
self
.
new_token_ratio
=
self
.
min_new_token_ratio
self
.
new_token_ratio
=
self
.
min_new_token_ratio
self
.
new_token_ratio_decay
=
global_config
.
new_token_ratio_decay
self
.
new_token_ratio_decay
=
global_config
.
new_token_ratio_decay
def
exposed_step
(
self
,
recv_reqs
):
def
exposed_step
(
self
,
recv_reqs
:
List
):
try
:
try
:
# Recv requests
# Recv requests
for
recv_req
in
recv_reqs
:
for
recv_req
in
recv_reqs
:
...
@@ -247,7 +253,7 @@ class ModelTpServer:
...
@@ -247,7 +253,7 @@ class ModelTpServer:
# Print stats
# Print stats
if
self
.
tp_rank
==
0
and
self
.
decode_forward_ct
%
40
==
0
:
if
self
.
tp_rank
==
0
and
self
.
decode_forward_ct
%
40
==
0
:
self
.
print_stats
()
self
.
print_
decode_
stats
()
if
self
.
running_batch
.
is_empty
():
if
self
.
running_batch
.
is_empty
():
self
.
running_batch
=
None
self
.
running_batch
=
None
...
@@ -259,7 +265,7 @@ class ModelTpServer:
...
@@ -259,7 +265,7 @@ class ModelTpServer:
self
.
check_memory
()
self
.
check_memory
()
self
.
new_token_ratio
=
global_config
.
init_new_token_ratio
self
.
new_token_ratio
=
global_config
.
init_new_token_ratio
def
print_stats
(
self
):
def
print_
decode_
stats
(
self
):
num_used
=
self
.
max_total_num_tokens
-
(
num_used
=
self
.
max_total_num_tokens
-
(
self
.
token_to_kv_pool
.
available_size
()
+
self
.
tree_cache
.
evictable_size
()
self
.
token_to_kv_pool
.
available_size
()
+
self
.
tree_cache
.
evictable_size
()
)
)
...
@@ -276,7 +282,6 @@ class ModelTpServer:
...
@@ -276,7 +282,6 @@ class ModelTpServer:
)
)
def
check_memory
(
self
):
def
check_memory
(
self
):
crash
=
os
.
getenv
(
"CI"
,
"false"
)
==
"true"
available_size
=
(
available_size
=
(
self
.
token_to_kv_pool
.
available_size
()
+
self
.
tree_cache
.
evictable_size
()
self
.
token_to_kv_pool
.
available_size
()
+
self
.
tree_cache
.
evictable_size
()
)
)
...
@@ -286,7 +291,7 @@ class ModelTpServer:
...
@@ -286,7 +291,7 @@ class ModelTpServer:
f
"available_size=
{
available_size
}
, max_total_num_tokens=
{
self
.
max_total_num_tokens
}
\n
"
f
"available_size=
{
available_size
}
, max_total_num_tokens=
{
self
.
max_total_num_tokens
}
\n
"
"KV cache pool leak detected!"
"KV cache pool leak detected!"
)
)
exit
(
1
)
if
crash
else
None
exit
(
1
)
if
crash
_on_warning
else
None
if
len
(
self
.
req_to_token_pool
.
free_slots
)
!=
self
.
req_to_token_pool
.
size
:
if
len
(
self
.
req_to_token_pool
.
free_slots
)
!=
self
.
req_to_token_pool
.
size
:
warnings
.
warn
(
warnings
.
warn
(
...
@@ -295,7 +300,7 @@ class ModelTpServer:
...
@@ -295,7 +300,7 @@ class ModelTpServer:
f
"total slots=
{
self
.
req_to_token_pool
.
size
}
\n
"
f
"total slots=
{
self
.
req_to_token_pool
.
size
}
\n
"
"Memory pool leak detected!"
"Memory pool leak detected!"
)
)
exit
(
1
)
if
crash
else
None
exit
(
1
)
if
crash
_on_warning
else
None
def
handle_generate_request
(
def
handle_generate_request
(
self
,
self
,
...
@@ -511,7 +516,14 @@ class ModelTpServer:
...
@@ -511,7 +516,14 @@ class ModelTpServer:
self
.
handle_finished_requests
(
batch
)
self
.
handle_finished_requests
(
batch
)
def
add_logprob_return_values
(
self
,
i
,
req
:
Req
,
pt
,
next_token_ids
,
output
):
def
add_logprob_return_values
(
self
,
i
,
req
:
Req
,
pt
:
int
,
next_token_ids
:
List
[
int
],
output
:
LogitProcessorOutput
,
):
if
req
.
normalized_prompt_logprob
is
None
:
if
req
.
normalized_prompt_logprob
is
None
:
req
.
normalized_prompt_logprob
=
output
.
normalized_prompt_logprobs
[
i
]
req
.
normalized_prompt_logprob
=
output
.
normalized_prompt_logprobs
[
i
]
...
@@ -786,7 +798,11 @@ def run_tp_server(
...
@@ -786,7 +798,11 @@ def run_tp_server(
def
launch_tp_servers
(
def
launch_tp_servers
(
gpu_ids
,
tp_rank_range
,
server_args
,
nccl_port
,
model_overide_args
gpu_ids
:
List
[
int
],
tp_rank_range
:
List
[
int
],
server_args
:
ServerArgs
,
nccl_port
:
int
,
model_overide_args
:
dict
,
):
):
"""Launch multiple tensor parallel servers."""
"""Launch multiple tensor parallel servers."""
procs
=
[]
procs
=
[]
...
@@ -801,7 +817,9 @@ def launch_tp_servers(
...
@@ -801,7 +817,9 @@ def launch_tp_servers(
return
procs
return
procs
def
broadcast_recv_input
(
data
,
rank
,
dist_group
):
def
broadcast_recv_input
(
data
:
Any
,
rank
:
int
,
dist_group
:
torch
.
distributed
.
ProcessGroup
):
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
if
rank
==
0
:
if
rank
==
0
:
...
...
python/sglang/srt/mem_cache/base_prefix_cache.py
View file @
9dae4078
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Callable
class
BasePrefixCache
(
ABC
):
class
BasePrefixCache
(
ABC
):
...
@@ -25,7 +26,7 @@ class BasePrefixCache(ABC):
...
@@ -25,7 +26,7 @@ class BasePrefixCache(ABC):
pass
pass
@
abstractmethod
@
abstractmethod
def
evict
(
self
,
num_tokens
,
evict_callback
):
def
evict
(
self
,
num_tokens
:
int
,
evict_callback
:
Callable
):
pass
pass
@
abstractmethod
@
abstractmethod
...
@@ -41,7 +42,7 @@ class BasePrefixCache(ABC):
...
@@ -41,7 +42,7 @@ class BasePrefixCache(ABC):
pass
pass
def
total_size
(
self
):
def
total_size
(
self
):
raise
NotImplementedError
raise
NotImplementedError
()
def
pretty_print
(
self
):
def
pretty_print
(
self
):
raise
NotImplementedError
raise
NotImplementedError
()
python/sglang/srt/mem_cache/chunk_cache.py
View file @
9dae4078
from
__future__
import
annotations
"""Cache for chunked prefill, used when RadixCache is disabled."""
"""Cache for chunked prefill, used when RadixCache is disabled."""
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
,
Callable
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.memory_pool
import
BaseTokenToKVPool
,
ReqToTokenPool
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.managers.schedule_batch
import
Req
from
sglang.srt.managers.schedule_batch
import
Req
...
@@ -15,7 +18,9 @@ class ChunkCacheEntry:
...
@@ -15,7 +18,9 @@ class ChunkCacheEntry:
class
ChunkCache
(
BasePrefixCache
):
class
ChunkCache
(
BasePrefixCache
):
def
__init__
(
self
,
req_to_token_pool
,
token_to_kv_pool
):
def
__init__
(
self
,
req_to_token_pool
:
ReqToTokenPool
,
token_to_kv_pool
:
BaseTokenToKVPool
):
self
.
disable
=
True
self
.
disable
=
True
self
.
req_to_token_pool
=
req_to_token_pool
self
.
req_to_token_pool
=
req_to_token_pool
self
.
token_to_kv_pool
=
token_to_kv_pool
self
.
token_to_kv_pool
=
token_to_kv_pool
...
@@ -32,7 +37,7 @@ class ChunkCache(BasePrefixCache):
...
@@ -32,7 +37,7 @@ class ChunkCache(BasePrefixCache):
entry
=
self
.
entries
[
rid
]
entry
=
self
.
entries
[
rid
]
return
entry
.
value
,
entry
return
entry
.
value
,
entry
def
cache_finished_req
(
self
,
req
:
"
Req
"
,
token_ids
=
None
):
def
cache_finished_req
(
self
,
req
:
Req
,
token_ids
:
Optional
[
List
[
int
]]
=
None
):
if
token_ids
is
None
:
if
token_ids
is
None
:
token_ids
=
(
req
.
origin_input_ids
+
req
.
output_ids
)[:
-
1
]
token_ids
=
(
req
.
origin_input_ids
+
req
.
output_ids
)[:
-
1
]
...
@@ -45,7 +50,7 @@ class ChunkCache(BasePrefixCache):
...
@@ -45,7 +50,7 @@ class ChunkCache(BasePrefixCache):
if
req
.
rid
in
self
.
entries
:
if
req
.
rid
in
self
.
entries
:
del
self
.
entries
[
req
.
rid
]
del
self
.
entries
[
req
.
rid
]
def
cache_unfinished_req
(
self
,
req
:
"
Req
"
,
token_ids
=
None
):
def
cache_unfinished_req
(
self
,
req
:
Req
,
token_ids
:
Optional
[
List
[
int
]]
=
None
):
if
token_ids
is
None
:
if
token_ids
is
None
:
token_ids
=
req
.
fill_ids
token_ids
=
req
.
fill_ids
...
@@ -64,7 +69,7 @@ class ChunkCache(BasePrefixCache):
...
@@ -64,7 +69,7 @@ class ChunkCache(BasePrefixCache):
def
insert
(
self
):
def
insert
(
self
):
raise
NotImplementedError
raise
NotImplementedError
def
evict
(
self
,
num_tokens
,
evict_callback
):
def
evict
(
self
,
num_tokens
:
int
,
evict_callback
:
Callable
):
pass
pass
def
inc_lock_ref
(
self
,
node
):
def
inc_lock_ref
(
self
,
node
):
...
...
python/sglang/srt/mem_cache/memory_pool.py
View file @
9dae4078
...
@@ -16,7 +16,7 @@ limitations under the License.
...
@@ -16,7 +16,7 @@ limitations under the License.
"""Memory pool."""
"""Memory pool."""
import
logging
import
logging
from
typing
import
List
from
typing
import
List
,
Union
import
torch
import
torch
...
@@ -42,7 +42,7 @@ class ReqToTokenPool:
...
@@ -42,7 +42,7 @@ class ReqToTokenPool:
return
select_index
return
select_index
def
free
(
self
,
free_index
):
def
free
(
self
,
free_index
:
Union
[
int
,
List
[
int
]]
):
if
isinstance
(
free_index
,
(
int
,)):
if
isinstance
(
free_index
,
(
int
,)):
self
.
free_slots
.
append
(
free_index
)
self
.
free_slots
.
append
(
free_index
)
else
:
else
:
...
...
python/sglang/srt/mem_cache/radix_cache.py
View file @
9dae4078
from
__future__
import
annotations
"""
"""
Copyright 2023-2024 SGLang Team
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -25,6 +27,7 @@ from typing import TYPE_CHECKING
...
@@ -25,6 +27,7 @@ from typing import TYPE_CHECKING
import
torch
import
torch
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.memory_pool
import
BaseTokenToKVPool
,
ReqToTokenPool
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.managers.schedule_batch
import
Req
from
sglang.srt.managers.schedule_batch
import
Req
...
@@ -43,7 +46,7 @@ class TreeNode:
...
@@ -43,7 +46,7 @@ class TreeNode:
return
self
.
last_access_time
<
other
.
last_access_time
return
self
.
last_access_time
<
other
.
last_access_time
def
_key_match
(
key0
,
key1
):
def
_key_match
(
key0
:
List
,
key1
:
List
):
i
=
0
i
=
0
for
k0
,
k1
in
zip
(
key0
,
key1
):
for
k0
,
k1
in
zip
(
key0
,
key1
):
if
k0
!=
k1
:
if
k0
!=
k1
:
...
@@ -53,7 +56,12 @@ def _key_match(key0, key1):
...
@@ -53,7 +56,12 @@ def _key_match(key0, key1):
class
RadixCache
(
BasePrefixCache
):
class
RadixCache
(
BasePrefixCache
):
def
__init__
(
self
,
req_to_token_pool
,
token_to_kv_pool
,
disable
:
bool
=
False
):
def
__init__
(
self
,
req_to_token_pool
:
ReqToTokenPool
,
token_to_kv_pool
:
BaseTokenToKVPool
,
disable
:
bool
=
False
,
):
self
.
req_to_token_pool
=
req_to_token_pool
self
.
req_to_token_pool
=
req_to_token_pool
self
.
token_to_kv_pool
=
token_to_kv_pool
self
.
token_to_kv_pool
=
token_to_kv_pool
self
.
disable
=
disable
self
.
disable
=
disable
...
@@ -68,7 +76,7 @@ class RadixCache(BasePrefixCache):
...
@@ -68,7 +76,7 @@ class RadixCache(BasePrefixCache):
self
.
root_node
.
lock_ref
=
1
self
.
root_node
.
lock_ref
=
1
self
.
evictable_size_
=
0
self
.
evictable_size_
=
0
def
match_prefix
(
self
,
key
,
**
kwargs
):
def
match_prefix
(
self
,
key
:
List
,
**
kwargs
):
if
self
.
disable
:
if
self
.
disable
:
return
[],
self
.
root_node
return
[],
self
.
root_node
...
@@ -81,7 +89,7 @@ class RadixCache(BasePrefixCache):
...
@@ -81,7 +89,7 @@ class RadixCache(BasePrefixCache):
value
=
torch
.
tensor
([],
dtype
=
torch
.
int32
)
value
=
torch
.
tensor
([],
dtype
=
torch
.
int32
)
return
value
,
last_node
[
0
]
return
value
,
last_node
[
0
]
def
insert
(
self
,
key
,
value
=
None
):
def
insert
(
self
,
key
:
List
,
value
=
None
):
if
self
.
disable
:
if
self
.
disable
:
return
0
return
0
...
@@ -89,7 +97,7 @@ class RadixCache(BasePrefixCache):
...
@@ -89,7 +97,7 @@ class RadixCache(BasePrefixCache):
value
=
[
x
for
x
in
key
]
value
=
[
x
for
x
in
key
]
return
self
.
_insert_helper
(
self
.
root_node
,
key
,
value
)
return
self
.
_insert_helper
(
self
.
root_node
,
key
,
value
)
def
cache_finished_req
(
self
,
req
:
"
Req
"
,
token_ids
=
None
):
def
cache_finished_req
(
self
,
req
:
Req
,
token_ids
:
Optional
[
List
[
int
]]
=
None
):
"""Cache request when it finishes."""
"""Cache request when it finishes."""
if
token_ids
is
None
:
if
token_ids
is
None
:
token_ids
=
(
req
.
origin_input_ids
+
req
.
output_ids
)[:
-
1
]
token_ids
=
(
req
.
origin_input_ids
+
req
.
output_ids
)[:
-
1
]
...
@@ -110,7 +118,7 @@ class RadixCache(BasePrefixCache):
...
@@ -110,7 +118,7 @@ class RadixCache(BasePrefixCache):
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
self
.
dec_lock_ref
(
req
.
last_node
)
self
.
dec_lock_ref
(
req
.
last_node
)
def
cache_unfinished_req
(
self
,
req
:
"
Req
"
,
token_ids
=
None
):
def
cache_unfinished_req
(
self
,
req
:
Req
,
token_ids
:
Optional
[
List
[
int
]]
=
None
):
"""Cache request when it is unfinished."""
"""Cache request when it is unfinished."""
if
self
.
disable
:
if
self
.
disable
:
return
return
...
@@ -145,7 +153,7 @@ class RadixCache(BasePrefixCache):
...
@@ -145,7 +153,7 @@ class RadixCache(BasePrefixCache):
def
total_size
(
self
):
def
total_size
(
self
):
return
self
.
_total_size_helper
(
self
.
root_node
)
return
self
.
_total_size_helper
(
self
.
root_node
)
def
evict
(
self
,
num_tokens
,
evict_callback
):
def
evict
(
self
,
num_tokens
:
int
,
evict_callback
:
Callable
):
if
self
.
disable
:
if
self
.
disable
:
return
return
...
@@ -199,7 +207,9 @@ class RadixCache(BasePrefixCache):
...
@@ -199,7 +207,9 @@ class RadixCache(BasePrefixCache):
##### Internal Helper Functions #####
##### Internal Helper Functions #####
def
_match_prefix_helper
(
self
,
node
,
key
,
value
,
last_node
):
def
_match_prefix_helper
(
self
,
node
:
TreeNode
,
key
:
List
,
value
,
last_node
:
TreeNode
):
node
.
last_access_time
=
time
.
time
()
node
.
last_access_time
=
time
.
time
()
if
len
(
key
)
==
0
:
if
len
(
key
)
==
0
:
return
return
...
@@ -216,7 +226,7 @@ class RadixCache(BasePrefixCache):
...
@@ -216,7 +226,7 @@ class RadixCache(BasePrefixCache):
last_node
[
0
]
=
child
last_node
[
0
]
=
child
self
.
_match_prefix_helper
(
child
,
key
[
prefix_len
:],
value
,
last_node
)
self
.
_match_prefix_helper
(
child
,
key
[
prefix_len
:],
value
,
last_node
)
def
_split_node
(
self
,
key
,
child
:
TreeNode
,
split_len
):
def
_split_node
(
self
,
key
,
child
:
TreeNode
,
split_len
:
int
):
# new_node -> child
# new_node -> child
new_node
=
TreeNode
()
new_node
=
TreeNode
()
new_node
.
children
=
{
key
[
split_len
:][
0
]:
child
}
new_node
.
children
=
{
key
[
split_len
:][
0
]:
child
}
...
@@ -230,7 +240,7 @@ class RadixCache(BasePrefixCache):
...
@@ -230,7 +240,7 @@ class RadixCache(BasePrefixCache):
new_node
.
parent
.
children
[
key
[:
split_len
][
0
]]
=
new_node
new_node
.
parent
.
children
[
key
[:
split_len
][
0
]]
=
new_node
return
new_node
return
new_node
def
_insert_helper
(
self
,
node
,
key
,
value
):
def
_insert_helper
(
self
,
node
:
TreeNode
,
key
:
List
,
value
):
node
.
last_access_time
=
time
.
time
()
node
.
last_access_time
=
time
.
time
()
if
len
(
key
)
==
0
:
if
len
(
key
)
==
0
:
return
0
return
0
...
@@ -261,7 +271,7 @@ class RadixCache(BasePrefixCache):
...
@@ -261,7 +271,7 @@ class RadixCache(BasePrefixCache):
self
.
evictable_size_
+=
len
(
value
)
self
.
evictable_size_
+=
len
(
value
)
return
0
return
0
def
_print_helper
(
self
,
node
:
TreeNode
,
indent
):
def
_print_helper
(
self
,
node
:
TreeNode
,
indent
:
int
):
for
_
,
child
in
node
.
children
.
items
():
for
_
,
child
in
node
.
children
.
items
():
print
(
" "
*
indent
,
len
(
child
.
key
),
child
.
key
[:
10
],
f
"r=
{
child
.
lock_ref
}
"
)
print
(
" "
*
indent
,
len
(
child
.
key
),
child
.
key
[:
10
],
f
"r=
{
child
.
lock_ref
}
"
)
self
.
_print_helper
(
child
,
indent
=
indent
+
2
)
self
.
_print_helper
(
child
,
indent
=
indent
+
2
)
...
@@ -273,7 +283,7 @@ class RadixCache(BasePrefixCache):
...
@@ -273,7 +283,7 @@ class RadixCache(BasePrefixCache):
del
node
.
parent
.
children
[
k
]
del
node
.
parent
.
children
[
k
]
self
.
evictable_size_
-=
len
(
node
.
key
)
self
.
evictable_size_
-=
len
(
node
.
key
)
def
_total_size_helper
(
self
,
node
):
def
_total_size_helper
(
self
,
node
:
TreeNode
):
x
=
len
(
node
.
value
)
x
=
len
(
node
.
value
)
for
child
in
node
.
children
.
values
():
for
child
in
node
.
children
.
values
():
x
+=
self
.
_total_size_helper
(
child
)
x
+=
self
.
_total_size_helper
(
child
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment