sglang / Commits / 47367b76

Commit 47367b76 (Unverified), authored Jun 19, 2025 by DarkSharpness; committed by GitHub on Jun 20, 2025.

    [Refactor] Clean up radix cache related API (#7303)

    Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>

Parent: 650127a1

Showing 7 changed files with 153 additions and 122 deletions (+153 -122):

    python/sglang/srt/managers/schedule_batch.py      +15  -18
    python/sglang/srt/managers/schedule_policy.py     +26  -32
    python/sglang/srt/managers/scheduler.py            +8  -18
    python/sglang/srt/mem_cache/base_prefix_cache.py  +52   -8
    python/sglang/srt/mem_cache/chunk_cache.py         +7  -13
    python/sglang/srt/mem_cache/hiradix_cache.py      +29  -21
    python/sglang/srt/mem_cache/radix_cache.py        +16  -12
python/sglang/srt/managers/schedule_batch.py:

```diff
@@ -38,7 +38,7 @@ import logging
 import threading
 from enum import Enum, auto
 from http import HTTPStatus
-from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
```

```diff
@@ -436,7 +436,7 @@ class Req:
         self,
         rid: str,
         origin_input_text: str,
-        origin_input_ids: Tuple[int],
+        origin_input_ids: List[int],
         sampling_params: SamplingParams,
         return_logprob: bool = False,
         top_logprobs_num: int = 0,
```

```diff
@@ -467,7 +467,7 @@ class Req:
         # Each decode stage's output ids
         self.output_ids = []
         # fill_ids = origin_input_ids + output_ids. Updated if chunked.
-        self.fill_ids = None
+        self.fill_ids = []
         self.session_id = session_id
         self.input_embeds = input_embeds
```

```diff
@@ -519,13 +519,14 @@ class Req:
         # Prefix info
         # The indices to kv cache for the shared prefix.
-        self.prefix_indices = []
+        self.prefix_indices: torch.Tensor = []
         # Number of tokens to run prefill.
         self.extend_input_len = 0
         # The relative logprob_start_len in an extend batch
         self.extend_logprob_start_len = 0
-        self.last_node = None
-        self.last_node_global = None
+        self.last_node: Any = None
+        self.last_host_node: Any = None
+        self.host_hit_length = 0
 
         # Whether or not if it is chunked. It increments whenever
         # it is chunked, and decrement whenever chunked request is
```

```diff
@@ -644,21 +645,17 @@ class Req:
     def init_next_round_input(
         self,
         tree_cache: Optional[BasePrefixCache] = None,
-        enable_hierarchical_cache=False,
     ):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
             # tree cache is None if the prefix is not computed with tree cache.
-            if enable_hierarchical_cache:
-                self.prefix_indices, self.last_node, self.last_node_global = (
-                    tree_cache.match_prefix(
-                        key=self.adjust_max_prefix_ids(), include_evicted=True
-                    )
-                )
-            else:
-                self.prefix_indices, self.last_node = tree_cache.match_prefix(
-                    rid=self.rid, key=self.adjust_max_prefix_ids()
-                )
+            (
+                self.prefix_indices,
+                self.last_node,
+                self.last_host_node,
+                self.host_hit_length,
+            ) = tree_cache.match_prefix(
+                key=self.adjust_max_prefix_ids(),
+            )
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
     def adjust_max_prefix_ids(self):
```
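The hunk above collapses two call shapes into one because `match_prefix` now always returns four values, with the host-related fields defaulting to "no host hit". A minimal runnable sketch of that mechanism, using stand-in types rather than the real sglang classes (the actual `MatchResult` is defined in `base_prefix_cache.py` below):

```python
# Stand-in sketch: a NamedTuple with a defaulted field unpacks into exactly
# four names, so non-hierarchical caches no longer need a separate
# two-value return path.
from typing import Any, NamedTuple


class MatchResult(NamedTuple):  # same shape as the sglang definition below
    device_indices: Any
    last_device_node: Any
    last_host_node: Any
    host_hit_length: int = 0


def match_prefix_no_hicache(key):
    node = "device-node"  # stand-in for a TreeNode
    # last_host_node mirrors last_device_node when HiCache is off
    return MatchResult([], node, node)


prefix_indices, last_node, last_host_node, host_hit_length = (
    match_prefix_no_hicache([1, 2, 3])
)
assert last_host_node == last_node and host_hit_length == 0
```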
python/sglang/srt/managers/schedule_policy.py:

```diff
@@ -90,7 +90,7 @@ class SchedulePolicy:
     def calc_priority(self, waiting_queue: List[Req]) -> bool:
         if self.policy == CacheAgnosticPolicy.FCFS:
             # A shortcut for FCFS
-            return
+            return False
 
         policy = self._determine_active_policy(waiting_queue)
```

```diff
@@ -134,7 +134,7 @@ class SchedulePolicy:
         """
         try:
             policy_enum = CacheAwarePolicy(policy)
-            if tree_cache.disable:
+            if getattr(tree_cache, "disable", True):
                 # If tree_cache is disabled, using CacheAgnosticPolicy policy
                 return CacheAgnosticPolicy.FCFS
             return policy_enum
```

```diff
@@ -158,14 +158,9 @@ class SchedulePolicy:
             prefix_ids = r.adjust_max_prefix_ids()
 
             # NOTE: the prefix_indices must always be aligned with last_node
-            if self.enable_hierarchical_cache:
-                r.prefix_indices, r.last_node, r.last_node_global = (
-                    self.tree_cache.match_prefix(
-                        key=prefix_ids, include_evicted=True
-                    )
-                )
-            else:
-                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
-                    rid=r.rid, key=prefix_ids
-                )
+            r.prefix_indices, r.last_node, r.last_host_node, r.host_hit_length = (
+                self.tree_cache.match_prefix(rid=r.rid, key=prefix_ids)
+            )
 
             # NOTE(sang): This logic is for in-batch prefix caching;
             # If there are more than 1 request that have small matching prefix from
```

```diff
@@ -175,7 +170,7 @@ class SchedulePolicy:
             # threshold means we cannot use in-batch prefix caching for short prefixes.
             # It is kind of common when the engine is long running (e.g., imagine the prefix "the").
             if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
-                in_batch_matching_prefixes, _ = (
+                in_batch_matching_prefixes, _, _, _ = (
                     self.waiting_queue_radix_tree.match_prefix(
                         rid=r.rid, key=prefix_ids
                     )
```

```diff
@@ -268,6 +263,7 @@ class AddReqResult(Enum):
 class PrefillAdder:
     def __init__(
         self,
+        page_size: int,
         tree_cache: BasePrefixCache,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         running_batch: ScheduleBatch,
```

```diff
@@ -276,6 +272,7 @@ class PrefillAdder:
         rem_chunk_tokens: Optional[int],
         mixed_with_decode_tokens: int = 0,
     ):
+        self.page_size = page_size
         self.tree_cache = tree_cache
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.running_batch = running_batch
```

```diff
@@ -442,46 +439,43 @@ class PrefillAdder:
         return self.budget_state()
 
-    def add_one_req(
-        self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
-    ):
+    def add_one_req(self, req: Req, has_chunked_req: bool):
         if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
             return self.add_one_req_ignore_eos(req, has_chunked_req)
 
         total_tokens = req.extend_input_len + min(
             req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
         )
-        input_tokens = (
-            -(-req.extend_input_len // self.tree_cache.page_size)
-            * self.tree_cache.page_size
-        )
+
+        # adjusting the input_tokens based on host_hit_length and page_size
+        real_input_tokens = req.extend_input_len - req.host_hit_length
+        real_input_tokens = -(-real_input_tokens // self.page_size) * self.page_size
         prefix_len = len(req.prefix_indices)
 
         if total_tokens >= self.rem_total_tokens:
             return AddReqResult.NO_TOKEN
 
-        if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
+        if real_input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
             return AddReqResult.OTHER
 
         with self._lock_node(req.last_node):
-            if total_tokens > self.rem_total_tokens:
+            # self.rem_total_tokens may decrease after the lock acquisition
+            if total_tokens >= self.rem_total_tokens:
                 return AddReqResult.NO_TOKEN
 
-            if (
-                enable_hierarchical_cache
-                and req.last_node_global is not None
-                and req.last_node_global.evicted
-            ):
-                req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
-                    req.last_node_global, req.prefix_indices
-                )
+            if req.host_hit_length > 0:
+                new_indices, req.last_node = self.tree_cache.init_load_back(
+                    req.last_host_node, req.host_hit_length
+                )
+                req.prefix_indices = torch.cat([req.prefix_indices, new_indices])
                 req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
-                input_tokens = (
-                    -(-req.extend_input_len // self.tree_cache.page_size)
-                    * self.tree_cache.page_size
-                )
                 prefix_len = len(req.prefix_indices)
 
+            input_tokens = -(-req.extend_input_len // self.page_size) * self.page_size
+
+            if input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
+                return AddReqResult.OTHER
+
             if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
                 # Non-chunked prefill
                 self.can_run_list.append(req)
```

```diff
@@ -496,7 +490,7 @@ class PrefillAdder:
             )
         else:
             # Make sure at least one page is available
-            trunc_len = self.rem_chunk_tokens - self.tree_cache.page_size + 1
+            trunc_len = self.rem_chunk_tokens - self.page_size + 1
             if trunc_len <= 0:
                 return AddReqResult.OTHER
```
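The rounding idiom used throughout `add_one_req`, `-(-n // page_size) * page_size`, is ceiling division written with Python's floor division: it rounds a token count up to the next page boundary. A quick standalone check (illustrative only):

```python
# -(-n // p) is ceiling division for positive p, so this rounds a token
# count up to the next multiple of the page size.
def round_up_to_page(n: int, page_size: int) -> int:
    return -(-n // page_size) * page_size


assert round_up_to_page(0, 4) == 0
assert round_up_to_page(1, 4) == 4
assert round_up_to_page(4, 4) == 4
assert round_up_to_page(5, 4) == 8
```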
python/sglang/srt/managers/scheduler.py:

```diff
@@ -1467,15 +1467,14 @@ class Scheduler(
             return None
 
         if self.enable_hierarchical_cache:
-            # check for completion of hierarchical cache activities to release memory
-            self.tree_cache.writing_check()
-            self.tree_cache.loading_check()
+            self.tree_cache.check_hicache_events()
 
         # Get priority queue
-        prefix_computed = self.policy.calc_priority(self.waiting_queue)
+        self.policy.calc_priority(self.waiting_queue)
 
         # Prefill policy
         adder = PrefillAdder(
+            self.page_size,
             self.tree_cache,
             self.token_to_kv_pool_allocator,
             self.running_batch,
```

```diff
@@ -1517,19 +1516,8 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break
 
-            # bypass prefix_computed if enable_hierarchical_cache
-            req.init_next_round_input(
-                (
-                    None
-                    if (prefix_computed and not self.enable_hierarchical_cache)
-                    else self.tree_cache
-                ),
-                self.enable_hierarchical_cache,
-            )
-
-            res = adder.add_one_req(
-                req, self.chunked_req, self.enable_hierarchical_cache
-            )
+            req.init_next_round_input(self.tree_cache)
+            res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
 
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
```

```diff
@@ -1581,7 +1569,9 @@ class Scheduler(
         )
 
         if self.enable_hierarchical_cache:
             # todo (zhiqiang): disable cuda graph execution if hicache loading triggered
-            new_batch.hicache_consumer_index = self.tree_cache.ready_to_load_cache()
+            new_batch.hicache_consumer_index = (
+                self.tree_cache.ready_to_load_host_cache()
+            )
 
         new_batch.prepare_for_extend()
```
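`check_hicache_events()` replaces the scheduler's direct `writing_check()`/`loading_check()` pair; as the `hiradix_cache.py` hunk below shows, it simply wraps both, so the cache owns its own bookkeeping. Note also that the old call passed `self.chunked_req` (a `Req` or `None`) where `add_one_req` expects a bool, while the new call passes `self.chunked_req is not None`. A toy stand-in (not the real `HiRadixCache`) illustrating the consolidation:

```python
# Toy stand-in: the scheduler calls one entry point and the cache decides
# which internal checks to run. Method bodies are illustrative stubs.
class ToyHiCache:
    def writing_check(self) -> None:
        print("check device->host writes")

    def loading_check(self) -> None:
        print("check host->device loads")

    def check_hicache_events(self) -> None:
        # single entry point used by the scheduler after this commit
        self.writing_check()
        self.loading_check()


if __name__ == "__main__":
    ToyHiCache().check_hicache_events()
```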
python/sglang/srt/mem_cache/base_prefix_cache.py:

```diff
 from abc import ABC, abstractmethod
-from typing import Any, List, Tuple
+from typing import TYPE_CHECKING, Any, List, NamedTuple, Tuple
+
+import torch
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req
+else:
+    Req = Any  # Placeholder for Req type when not type checking
+
+
+class MatchResult(NamedTuple):
+    """Result of a prefix match operation.
+
+    Attributes:
+        device_indices  : Indices of the KV cache on the device matched by common prefix.
+        last_device_node: The last TreeNode on the device that was matched.
+        last_host_node  : The last TreeNode on the host that was matched.
+                          Note that if HiCache is not enabled,
+                          this **must** be the same as `last_device_node`.
+        host_hit_length : Length of the KV cache hit on the host, if applicable.
+                          0 if HiCache is not enabled.
+    """
+
+    device_indices: torch.Tensor
+    last_device_node: Any
+    last_host_node: Any
+    host_hit_length: int = 0
 
 
 class BasePrefixCache(ABC):
```

```diff
@@ -10,19 +36,15 @@ class BasePrefixCache(ABC):
         pass
 
     @abstractmethod
-    def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
+    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
         pass
 
     @abstractmethod
-    def insert(self, **kwargs):
-        pass
-
-    @abstractmethod
-    def cache_finished_req(self, **kwargs):
+    def cache_finished_req(self, req: Req, **kwargs):
         pass
 
     @abstractmethod
-    def cache_unfinished_req(self, **kwargs):
+    def cache_unfinished_req(self, req: Req, **kwargs):
         pass
 
     @abstractmethod
```

```diff
@@ -49,5 +71,27 @@ class BasePrefixCache(ABC):
     def pretty_print(self):
         raise NotImplementedError()
 
+    def init_load_back(
+        self,
+        last_host_node: Any,
+        host_hit_length: int,
+    ) -> Tuple[torch.Tensor, Any]:
+        """
+        Preparing KV cache loading from host to device.
+        """
+        raise NotImplementedError()
+
+    def ready_to_load_host_cache(self) -> Any:
+        """
+        Notify the cache controller to start the KV cache loading
+        """
+        raise NotImplementedError()
+
+    def check_hicache_events(self) -> Any:
+        """
+        Check HiCache related activities to update radix tree and synchronize across TP workers if needed
+        """
+        raise NotImplementedError()
+
     def take_events(self):
         return []
```
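A minimal sketch of a backend implementing the refactored interface; `NullPrefixCache` is a made-up stub (not part of sglang), but its return shapes follow the `MatchResult` definition and the new default methods above (requires torch):

```python
from typing import Any, List, NamedTuple, Tuple

import torch


class MatchResult(NamedTuple):  # as defined in base_prefix_cache.py
    device_indices: torch.Tensor
    last_device_node: Any
    last_host_node: Any
    host_hit_length: int = 0


class NullPrefixCache:
    """Illustrative stub: never matches anything, shows expected shapes."""

    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
        empty = torch.empty((0,), dtype=torch.int64)
        return MatchResult(empty, None, None)

    def init_load_back(
        self, last_host_node: Any, host_hit_length: int
    ) -> Tuple[torch.Tensor, Any]:
        # no host tier: nothing to load back to the device
        return torch.empty((0,), dtype=torch.int64), last_host_node

    def check_hicache_events(self) -> None:
        pass  # no HiCache activity to poll


res = NullPrefixCache().match_prefix(key=[1, 2, 3])
print(len(res.device_indices), res.host_hit_length)  # 0 0 (NamedTuple default)
```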
python/sglang/srt/mem_cache/chunk_cache.py:

```diff
@@ -6,19 +6,13 @@ from typing import TYPE_CHECKING, Any, Callable, List, Tuple
 import torch
 
-from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
 
 
-class ChunkCacheEntry:
-    def __init__(self, rid: str, value: torch.Tensor):
-        self.rid = rid
-        self.value = value
-
-
 class ChunkCache(BasePrefixCache):
     def __init__(
         self,
```

```diff
@@ -29,13 +23,16 @@ class ChunkCache(BasePrefixCache):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.page_size = page_size
+        self.disable = True
 
     def reset(self):
         pass
 
-    def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
-        return [], None
+    def match_prefix(self, **unused_kwargs) -> MatchResult:
+        return MatchResult(
+            device_indices=torch.empty((0,), dtype=torch.int64),
+            last_device_node=None,
+            last_host_node=None,
+        )
 
     def cache_finished_req(self, req: Req):
         kv_indices = self.req_to_token_pool.req_to_token[
```

```diff
@@ -54,9 +51,6 @@ class ChunkCache(BasePrefixCache):
         # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
         req.prefix_indices = kv_indices
 
-    def insert(self):
-        raise NotImplementedError()
-
     def evict(self, num_tokens: int):
         pass
```
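The explicit `self.disable = True` matters because scheduling code now probes the flag defensively with `getattr(tree_cache, "disable", True)` (see the `schedule_policy.py` hunks above): a cache that does no prefix sharing, or that lacks the attribute entirely, falls back to the cache-agnostic FCFS path. A toy check with made-up class names:

```python
# NoPrefixCache is a stand-in for ChunkCache's behavior after this commit.
class NoPrefixCache:
    disable = True  # ChunkCache sets this in __init__


def uses_cache_aware_policy(tree_cache) -> bool:
    # mirrors the getattr probe in schedule_policy.py
    return not getattr(tree_cache, "disable", True)


print(uses_cache_aware_policy(NoPrefixCache()))  # False -> fall back to FCFS
print(uses_cache_aware_policy(object()))         # False -> attribute missing
```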
python/sglang/srt/mem_cache/hiradix_cache.py:

```diff
@@ -7,6 +7,7 @@ from typing import List, Optional
 import torch
 
 from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.mem_cache.base_prefix_cache import MatchResult
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
```

```diff
@@ -283,41 +284,44 @@ class HiRadixCache(RadixCache):
     def init_load_back(
         self,
         last_node: TreeNode,
-        prefix_indices: torch.Tensor,
+        host_hit_length: int,
         mem_quota: Optional[int] = None,
     ):
-        assert (
-            len(prefix_indices) == 0 or prefix_indices.is_cuda
-        ), "indices of device kV caches should be on GPU"
+        _ = host_hit_length  # unused, but kept for compatibility
         if last_node.evicted:
             loading_values = self.load_back(last_node, mem_quota)
             if loading_values is not None:
-                prefix_indices = (
-                    loading_values
-                    if len(prefix_indices) == 0
-                    else torch.cat([prefix_indices, loading_values])
-                )
                 logger.debug(
                     f"loading back {len(loading_values)} tokens for node {last_node.id}"
                 )
+                return loading_values, last_node
+
             while last_node.evicted:
                 last_node = last_node.parent
 
-        return last_node, prefix_indices
+        return (
+            torch.empty((0,), dtype=torch.int64, device=self.device),
+            last_node,
+        )
 
-    def ready_to_load_cache(self):
+    def ready_to_load_host_cache(self):
         producer_index = self.cache_controller.layer_done_counter.next_producer()
         self.load_cache_event.set()
         return producer_index
 
-    def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
+    def check_hicache_events(self):
+        self.writing_check()
+        self.loading_check()
+
+    def match_prefix(self, key: List[int], **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
         if self.disable or len(key) == 0:
-            if include_evicted:
-                return empty_value, self.root_node, self.root_node
-            else:
-                return empty_value, self.root_node
+            return MatchResult(
+                device_indices=empty_value,
+                last_device_node=self.root_node,
+                last_host_node=self.root_node,
+                host_hit_length=0,
+            )
 
         if self.page_size != 1:
             page_aligned_len = len(key) // self.page_size * self.page_size
```

```diff
@@ -329,14 +333,18 @@ class HiRadixCache(RadixCache):
         else:
             value = empty_value
 
-        last_node_global = last_node
+        host_hit_length = 0
+        last_host_node = last_node
         while last_node.evicted:
+            host_hit_length += len(last_node.host_value)
             last_node = last_node.parent
 
-        if include_evicted:
-            return value, last_node, last_node_global
-        else:
-            return value, last_node
+        return MatchResult(
+            device_indices=value,
+            last_device_node=last_node,
+            last_host_node=last_host_node,
+            host_hit_length=host_hit_length,
+        )
 
     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.monotonic()
```
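The `host_hit_length` accounting above walks from the deepest matched node up through evicted ancestors, summing the host-resident KV lengths, while `last_host_node` remembers where the walk started. A self-contained toy of the same loop (the `Node` class is a stand-in for `TreeNode`):

```python
# Stand-in for TreeNode: parent link, eviction flag, host-resident indices.
class Node:
    def __init__(self, parent, evicted, host_value_len):
        self.parent = parent
        self.evicted = evicted
        self.host_value = list(range(host_value_len))  # stand-in indices


root = Node(None, evicted=False, host_value_len=0)
mid = Node(root, evicted=True, host_value_len=8)
leaf = Node(mid, evicted=True, host_value_len=4)

last_node, host_hit_length = leaf, 0
last_host_node = last_node
while last_node.evicted:
    host_hit_length += len(last_node.host_value)
    last_node = last_node.parent

print(host_hit_length)    # 12 tokens recoverable from the host tier
print(last_node is root)  # True: deepest node still resident on device
print(last_host_node is leaf)  # True: deepest matched node overall
```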
python/sglang/srt/mem_cache/radix_cache.py:

```diff
@@ -33,8 +33,7 @@ from sglang.srt.disaggregation.kv_events import (
     BlockStored,
     KVCacheEvent,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 
 if TYPE_CHECKING:
```

```diff
@@ -47,9 +46,9 @@ class TreeNode:
     def __init__(self, id: Optional[int] = None):
         self.children = defaultdict(TreeNode)
-        self.parent = None
-        self.key = None
-        self.value = None
+        self.parent: TreeNode = None
+        self.key: List[int] = None
+        self.value: Optional[torch.Tensor] = None
         self.lock_ref = 0
         self.last_access_time = time.monotonic()
```

```diff
@@ -57,7 +56,7 @@ class TreeNode:
         # indicating the node is loading KV cache from host
         self.loading = False
         # store the host indices of KV cache
-        self.host_value = None
+        self.host_value: Optional[torch.Tensor] = None
 
         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1
```

```diff
@@ -135,7 +134,7 @@ class RadixCache(BasePrefixCache):
         self.protected_size_ = 0
         self._record_all_cleared_event()
 
-    def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
+    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
         """Find the matching prefix from the radix tree.
         Args:
             key: A list of token IDs to find a matching prefix.
```

```diff
@@ -147,13 +146,14 @@ class RadixCache(BasePrefixCache):
             than the last node's value.
         """
         if self.disable or len(key) == 0:
-            return (
-                torch.empty(
+            return MatchResult(
+                device_indices=torch.empty(
                     (0,),
                     dtype=torch.int64,
                     device=self.device,
                 ),
-                self.root_node,
+                last_device_node=self.root_node,
+                last_host_node=self.root_node,
             )
 
         if self.page_size != 1:
```

```diff
@@ -165,7 +165,11 @@ class RadixCache(BasePrefixCache):
             value = torch.cat(value)
         else:
             value = torch.empty((0,), dtype=torch.int64, device=self.device)
-        return value, last_node
+        return MatchResult(
+            device_indices=value,
+            last_device_node=last_node,
+            last_host_node=last_node,
+        )
 
     def insert(self, key: List, value=None):
         if self.disable:
```

```diff
@@ -235,7 +239,7 @@ class RadixCache(BasePrefixCache):
         )
 
         # The prefix indices could be updated, reuse it
-        new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
+        new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
         self.req_to_token_pool.write(
             (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
             new_indices[len(req.prefix_indices) :],
```
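One reason the migration stays source-compatible: `match_prefix(self, key, **kwargs)` absorbs extra keyword arguments, so call sites that still pass `rid=` (as `schedule_policy.py` does above) keep working even though `RadixCache` only uses `key`. An illustrative stub of that pattern:

```python
# Extra keyword arguments such as rid land in kwargs and are ignored here,
# mirroring how RadixCache.match_prefix tolerates legacy call sites.
from typing import List


def match_prefix(key: List[int], **kwargs):
    return key[:2]  # stand-in for the real prefix-matching logic


print(match_prefix(key=[5, 6, 7], rid="req-42"))  # [5, 6]
```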