change / sglang / Commits / 7623091d

Commit 7623091d (unverified), authored Aug 07, 2024 by Liangsheng Yin, committed by GitHub on Aug 07, 2024

RadixCache method adjust (#977)

Parent: f724f1f1

Showing 5 changed files with 141 additions and 119 deletions (+141 / -119).
python/sglang/srt/managers/schedule_batch.py      (+22 / -13)
python/sglang/srt/managers/tp_worker.py           (+37 / -60)
python/sglang/srt/mem_cache/base_prefix_cache.py  (+5 / -1)
python/sglang/srt/mem_cache/chunk_cache.py        (+29 / -15)
python/sglang/srt/mem_cache/radix_cache.py        (+48 / -30)
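At a glance, the commit replaces the kwargs-heavy `cache_req(...)` entry point with two request-oriented methods, `cache_finished_req(req)` and `cache_unfinished_req(req)`, on every prefix cache. The sketch below only illustrates the call-shape change; `Req`, `OldStyleCache`, and `NewStyleCache` are simplified stand-ins, not the real sglang types.

# Illustrative stand-ins only: the real Req and cache classes live under
# python/sglang/srt/ and carry many more fields than shown here.
from dataclasses import dataclass, field
from typing import List


@dataclass
class Req:
    rid: str
    input_ids: List[int]
    output_ids: List[int] = field(default_factory=list)
    req_pool_idx: int = 0


class OldStyleCache:
    def cache_req(self, rid, token_ids, last_uncached_pos, req_pool_idx,
                  del_in_memory_pool=True, old_last_node=None, **kwargs):
        # Before this commit: one method, with del_in_memory_pool switching
        # between "request finished" and "request still running" behavior.
        print(f"cache_req rid={rid} finished={del_in_memory_pool}")


class NewStyleCache:
    def cache_finished_req(self, req: Req, token_ids=None):
        # After this commit: the Req object itself carries rid, pool index, etc.
        if token_ids is None:
            token_ids = (req.input_ids + req.output_ids)[:-1]
        print(f"cache_finished_req rid={req.rid} cached_tokens={len(token_ids)}")

    def cache_unfinished_req(self, req: Req, token_ids=None):
        if token_ids is None:
            token_ids = req.input_ids
        print(f"cache_unfinished_req rid={req.rid} cached_tokens={len(token_ids)}")


req = Req(rid="r0", input_ids=[1, 2, 3], output_ids=[4, 5])
OldStyleCache().cache_req(rid=req.rid,
                          token_ids=tuple(req.input_ids + req.output_ids)[:-1],
                          last_uncached_pos=0, req_pool_idx=req.req_pool_idx)
NewStyleCache().cache_finished_req(req)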
python/sglang/srt/managers/schedule_batch.py
@@ -124,7 +124,7 @@ class Req:
         # For vision input
         self.pixel_values = None
         self.image_size = None
-        self.image_offset = 0
+        self.image_offset = None
         self.pad_value = None

         # Prefix info
@@ -162,6 +162,13 @@ class Req:
     def finished(self) -> bool:
         return self.finished_reason is not None

+    def adjust_max_prefix_ids(self):
+        max_prefix_ids = self.input_ids
+        if self.return_logprob:
+            max_prefix_ids = self.input_ids[: self.logprob_start_len]
+        return max_prefix_ids
+
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
     def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
@@ -444,7 +451,8 @@ class ScheduleBatch:
         self.pixel_values = [r.pixel_values for r in reqs]
         self.image_sizes = [r.image_size for r in reqs]
         self.image_offsets = [
-            r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
+            (r.image_offset - p_len) if r.image_offset is not None else 0
+            for r, p_len in zip(reqs, prefix_lens)
         ]
         self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
         self.extend_num_tokens = extend_num_tokens
@@ -596,15 +604,7 @@ class ScheduleBatch:
                 req.vid += 1

                 # insert the old request into tree_cache
-                self.tree_cache.cache_req(
-                    rid=req.rid,
-                    token_ids=cur_all_ids,
-                    last_uncached_pos=len(req.prefix_indices),
-                    req_pool_idx=req.req_pool_idx,
-                )
-
-                # unlock the last node
-                self.tree_cache.dec_lock_ref(req.last_node)
+                self.tree_cache.cache_finished_req(req, cur_all_ids)

                 # re-applying image padding
                 if req.pixel_values is not None:
@@ -621,8 +621,7 @@ class ScheduleBatch:
                 jump_forward_reqs.append(req)
                 filter_indices.remove(i)

-        if len(filter_indices) < len(self.reqs):
-            self.filter_batch(filter_indices)
+        self.filter_batch(filter_indices)

         return jump_forward_reqs
@@ -644,6 +643,15 @@ class ScheduleBatch:
         ] = self.out_cache_loc

     def filter_batch(self, unfinished_indices: List[int]):
+        if unfinished_indices is None or len(unfinished_indices) == 0:
+            # Filter out all requests
+            self.reqs = []
+            return
+
+        if len(unfinished_indices) == len(self.reqs):
+            # No need to filter
+            return
+
         self.reqs = [self.reqs[i] for i in unfinished_indices]
         new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
         self.seq_lens = self.seq_lens[new_indices]
@@ -711,6 +719,7 @@ class ScheduleBatch:
...
@@ -711,6 +719,7 @@ class ScheduleBatch:
self
.
logit_bias
=
torch
.
concat
([
self
.
logit_bias
,
other
.
logit_bias
])
self
.
logit_bias
=
torch
.
concat
([
self
.
logit_bias
,
other
.
logit_bias
])
def
sample
(
self
,
logits
:
torch
.
Tensor
):
def
sample
(
self
,
logits
:
torch
.
Tensor
):
# TODO(lsyin): move this into a part of layer and run with CUDA Graph
# Post process logits
# Post process logits
logits
=
logits
.
contiguous
()
logits
=
logits
.
contiguous
()
logits
.
div_
(
self
.
temperatures
)
logits
.
div_
(
self
.
temperatures
)
...
...
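The guard clauses added to `ScheduleBatch.filter_batch` above are what let callers in tp_worker.py (next file) pass `unfinished_indices` unconditionally. A torch-free sketch of just that guard behaviour, using a hypothetical `TinyBatch` rather than the real `ScheduleBatch`:

from typing import List, Optional


class TinyBatch:
    """Hypothetical stand-in that mimics only the new guard clauses."""

    def __init__(self, reqs):
        self.reqs = list(reqs)

    def filter_batch(self, unfinished_indices: Optional[List[int]]):
        if unfinished_indices is None or len(unfinished_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return
        if len(unfinished_indices) == len(self.reqs):
            # No need to filter
            return
        self.reqs = [self.reqs[i] for i in unfinished_indices]


batch = TinyBatch(["a", "b", "c"])
batch.filter_batch([0, 2])   # keeps "a" and "c"
batch.filter_batch([0, 1])   # full-length list: no-op
batch.filter_batch([])       # empty list: clears the batch
print(batch.reqs)            # []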
python/sglang/srt/managers/tp_worker.py
@@ -232,8 +232,6 @@ class ModelTpServer:
             if new_batch is not None:
                 # Run a new prefill batch
                 self.forward_prefill_batch(new_batch)
-                self.cache_filled_batch(new_batch)
-                self.filter_out_inflight(new_batch)

                 if not new_batch.is_empty():
                     if self.running_batch is None:
@@ -353,26 +351,20 @@ class ModelTpServer:
             self.waiting_queue.append(req)

     def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
-        # TODO(lsyin): organize this function
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
         if running_bs >= self.max_running_requests:
-            return
+            return None

         # Compute matched prefix length
         for req in self.waiting_queue:
             req.input_ids = req.origin_input_ids + req.output_ids
-            try_match_ids = req.input_ids
-            if req.return_logprob:
-                try_match_ids = req.input_ids[: req.logprob_start_len]
-            # NOTE: the prefix_indices must always be aligned with last_node
-            prefix_indices, last_node = self.tree_cache.match_prefix(
-                rid=req.rid, key=try_match_ids
-            )
-            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
-            req.prefix_indices = prefix_indices
-            req.last_node = last_node
+            # NOTE: the prefix_indices must always be aligned with last_node
+            req.prefix_indices, req.last_node = self.tree_cache.match_prefix(
+                rid=req.rid, key=req.adjust_max_prefix_ids()
+            )
+            req.extend_input_len = len(req.input_ids) - len(req.prefix_indices)

         # Get priority queue
         self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
@@ -394,6 +386,24 @@ class ModelTpServer:
         )

         for req in self.waiting_queue:
+            # FIXME: Move this code into adjust_max_prefix_len
+            if req.return_logprob and req.normalized_prompt_logprob is None:
+                # Need at least two tokens to compute normalized logprob
+                if req.extend_input_len < 2:
+                    delta = 2 - req.extend_input_len
+                    req.extend_input_len += delta
+                    req.prefix_indices = req.prefix_indices[:-delta]
+                    if req.image_offset is not None:
+                        req.image_offset += delta
+            if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
+                # Need at least one token to compute logits
+                req.extend_input_len = 1
+                req.prefix_indices = req.prefix_indices[:-1]
+                if req.image_offset is not None:
+                    req.image_offset += 1
+
             res = adder.add_one_req(req)
             if (
                 not res
@@ -470,10 +480,20 @@ class ModelTpServer:
...
@@ -470,10 +480,20 @@ class ModelTpServer:
pt
=
0
pt
=
0
for
i
,
req
in
enumerate
(
batch
.
reqs
):
for
i
,
req
in
enumerate
(
batch
.
reqs
):
if
req
is
not
self
.
current_inflight_req
:
if
req
is
not
self
.
current_inflight_req
:
# Inflight reqs' prefill is not finished
req
.
completion_tokens_wo_jump_forward
+=
1
req
.
completion_tokens_wo_jump_forward
+=
1
req
.
output_ids
.
append
(
next_token_ids
[
i
])
req
.
output_ids
.
append
(
next_token_ids
[
i
])
req
.
check_finished
()
req
.
check_finished
()
if
req
.
finished
():
self
.
tree_cache
.
cache_finished_req
(
req
)
else
:
self
.
tree_cache
.
cache_unfinished_req
(
req
)
if
req
is
self
.
current_inflight_req
:
# Inflight request would get a new req idx
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
if
req
.
return_logprob
:
if
req
.
return_logprob
:
self
.
add_logprob_return_values
(
i
,
req
,
pt
,
next_token_ids
,
output
)
self
.
add_logprob_return_values
(
i
,
req
,
pt
,
next_token_ids
,
output
)
pt
+=
req
.
extend_input_len
pt
+=
req
.
extend_input_len
...
@@ -529,22 +549,6 @@ class ModelTpServer:
...
@@ -529,22 +549,6 @@ class ModelTpServer:
)
)
req
.
output_top_logprobs
.
append
(
output
.
output_top_logprobs
[
i
])
req
.
output_top_logprobs
.
append
(
output
.
output_top_logprobs
[
i
])
def
cache_filled_batch
(
self
,
batch
:
ScheduleBatch
):
for
i
,
req
in
enumerate
(
batch
.
reqs
):
new_prefix_indices
,
new_last_node
=
self
.
tree_cache
.
cache_req
(
rid
=
req
.
rid
,
token_ids
=
tuple
(
req
.
input_ids
),
last_uncached_pos
=
len
(
req
.
prefix_indices
),
req_pool_idx
=
req
.
req_pool_idx
,
del_in_memory_pool
=
False
,
old_last_node
=
req
.
last_node
,
)
req
.
prefix_indices
,
req
.
last_node
=
new_prefix_indices
,
new_last_node
if
req
is
self
.
current_inflight_req
:
# inflight request would get a new req idx
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
def
forward_decode_batch
(
self
,
batch
:
ScheduleBatch
):
def
forward_decode_batch
(
self
,
batch
:
ScheduleBatch
):
# Check if decode out of memory
# Check if decode out of memory
if
not
batch
.
check_decode_mem
():
if
not
batch
.
check_decode_mem
():
@@ -595,6 +599,9 @@ class ModelTpServer:
             req.output_ids.append(next_token_id)
             req.check_finished()

+            if req.finished():
+                self.tree_cache.cache_finished_req(req)
+
             if req.return_logprob:
                 req.output_token_logprobs.append(
                     (next_token_logprobs[i], next_token_id)
@@ -614,12 +621,9 @@ class ModelTpServer:
...
@@ -614,12 +621,9 @@ class ModelTpServer:
output_spaces_between_special_tokens
=
[]
output_spaces_between_special_tokens
=
[]
output_meta_info
=
[]
output_meta_info
=
[]
output_finished_reason
:
List
[
BaseFinishReason
]
=
[]
output_finished_reason
:
List
[
BaseFinishReason
]
=
[]
finished_indices
=
[]
unfinished_indices
=
[]
unfinished_indices
=
[]
for
i
,
req
in
enumerate
(
batch
.
reqs
):
for
i
,
req
in
enumerate
(
batch
.
reqs
):
if
req
.
finished
():
if
not
req
.
finished
()
and
req
is
not
self
.
current_inflight_req
:
finished_indices
.
append
(
i
)
else
:
unfinished_indices
.
append
(
i
)
unfinished_indices
.
append
(
i
)
if
req
.
finished
()
or
(
if
req
.
finished
()
or
(
@@ -683,34 +687,7 @@ class ModelTpServer:
                     )
                 )

-        # Remove finished reqs
-        if finished_indices:
-            # Update radix cache
-            for i in finished_indices:
-                req = batch.reqs[i]
-                self.tree_cache.cache_req(
-                    rid=req.rid,
-                    token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
-                    last_uncached_pos=len(req.prefix_indices),
-                    req_pool_idx=req.req_pool_idx,
-                )
-                self.tree_cache.dec_lock_ref(req.last_node)
-
-        # Update batch tensors
-        if unfinished_indices:
-            batch.filter_batch(unfinished_indices)
-        else:
-            batch.reqs = []
-
-    def filter_out_inflight(self, batch: ScheduleBatch):
-        # TODO(lsyin): reduce the overhead, make a special version for this
-        if self.current_inflight_req is None:
-            return
-
-        to_remove = batch.reqs.index(self.current_inflight_req)
-        unfinished_indices = [i for i in range(len(batch.reqs)) if i != to_remove]
-
-        batch.filter_batch(unfinished_indices)
+        # Remove finished reqs: update batch tensors
+        batch.filter_batch(unfinished_indices)

     def flush_cache(self):
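The net effect on the worker is that per-request caching now happens inside the prefill and decode loops instead of in the removed `cache_filled_batch` / `filter_out_inflight` helpers. A simplified sketch of that dispatch, with stub classes standing in for the real tree cache and request:

class FakeReq:
    """Illustrative stand-in for sglang's Req."""

    def __init__(self, rid, done):
        self.rid = rid
        self._done = done

    def finished(self):
        return self._done


class TreeCacheStub:
    """Illustrative stand-in for RadixCache / ChunkCache."""

    def cache_finished_req(self, req):
        print(f"{req.rid}: committed to the prefix cache, request slot freed")

    def cache_unfinished_req(self, req):
        print(f"{req.rid}: prefix cached, request keeps running")


tree_cache = TreeCacheStub()
for req in [FakeReq("r0", done=True), FakeReq("r1", done=False)]:
    # Same branch the new per-request loop takes after each forward step.
    if req.finished():
        tree_cache.cache_finished_req(req)
    else:
        tree_cache.cache_unfinished_req(req)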
python/sglang/srt/mem_cache/base_cache.py → python/sglang/srt/mem_cache/base_prefix_cache.py

@@ -17,7 +17,11 @@ class BasePrefixCache(ABC):
         pass

     @abstractmethod
-    def cache_req(self, **kwargs):
+    def cache_finished_req(self, **kwargs):
+        pass
+
+    @abstractmethod
+    def cache_unfinished_req(self, **kwargs):
         pass

     @abstractmethod
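How the renamed abstract interface reads after this change is sketched below. Only the two methods touched by the diff are reproduced; the real `BasePrefixCache` in base_prefix_cache.py declares further abstract methods beyond those shown as context above.

from abc import ABC, abstractmethod


class BasePrefixCacheSketch(ABC):
    """Partial sketch of the interface, not the full sglang ABC."""

    @abstractmethod
    def cache_finished_req(self, **kwargs):
        pass

    @abstractmethod
    def cache_unfinished_req(self, **kwargs):
        pass


class NoopCache(BasePrefixCacheSketch):
    """Subclasses such as RadixCache and ChunkCache must now implement both hooks."""

    def cache_finished_req(self, **kwargs):
        return None

    def cache_unfinished_req(self, **kwargs):
        return None


NoopCache()  # instantiation only succeeds because both abstract methods are overridden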
python/sglang/srt/mem_cache/chunk_cache.py
 """Cache for chunked prefill, used when RadixCache is disabled."""

-from sglang.srt.mem_cache.base_cache import BasePrefixCache
+from typing import TYPE_CHECKING
+
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req


 class ChunkCacheEntry:

@@ -27,22 +32,31 @@ class ChunkCache(BasePrefixCache):
         entry = self.entries[rid]
         return entry.value, entry

-    def cache_req(
-        self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs
-    ):
-        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
-
-        if del_in_memory_pool:
-            assert rid in self.entries
-            self.req_to_token_pool.free(req_pool_idx)
-            self.token_to_kv_pool.free(indices)
-            return
-
-        if rid not in self.entries:
-            self.entries[rid] = ChunkCacheEntry(rid, indices)
-
-        entry = self.entries[rid]
-        entry.value = indices
-        return indices, entry
+    def cache_finished_req(self, req: "Req", token_ids=None):
+        if token_ids is None:
+            token_ids = (req.input_ids + req.output_ids)[:-1]
+
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+        assert req.rid in self.entries
+        self.req_to_token_pool.free(req.req_pool_idx)
+        self.token_to_kv_pool.free(kv_indices)
+
+    def cache_unfinished_req(self, req: "Req", token_ids=None):
+        if token_ids is None:
+            token_ids = req.input_ids
+
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+
+        if req.rid not in self.entries:
+            self.entries[req.rid] = ChunkCacheEntry(req.rid, kv_indices)
+
+        entry = self.entries[req.rid]
+        entry.value = kv_indices
+        return kv_indices, entry

     def insert(self):
         raise NotImplementedError
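One detail worth calling out in both ChunkCache and RadixCache: the two methods pick different default token ranges. A finished request caches `(input_ids + output_ids)[:-1]`, dropping the last sampled token (which is never fed through another forward pass for a finished request, so it has no KV entry to reuse), while an unfinished request caches only `input_ids`. A tiny sketch with made-up token values:

# Hypothetical token ids, purely to show the default slicing used above.
input_ids = [101, 7592, 2088]    # prompt tokens
output_ids = [999, 1000, 102]    # generated tokens

finished_default = (input_ids + output_ids)[:-1]   # drops the final 102
unfinished_default = input_ids

print(finished_default)    # [101, 7592, 2088, 999, 1000]
print(unfinished_default)  # [101, 7592, 2088]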
python/sglang/srt/mem_cache/radix_cache.py
@@ -20,10 +20,14 @@ The radix tree data structure for managing the KV cache.
 import heapq
 import time
 from collections import defaultdict
+from typing import TYPE_CHECKING

 import torch

-from sglang.srt.mem_cache.base_cache import BasePrefixCache
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req


 class TreeNode:

@@ -85,40 +89,54 @@ class RadixCache(BasePrefixCache):
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)

-    def cache_req(
-        self,
-        token_ids,
-        last_uncached_pos,
-        req_pool_idx,
-        del_in_memory_pool=True,
-        old_last_node=None,
-        **kwargs,
-    ):
-        # Insert the request into radix cache
-        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
-        new_prefix_len = self.insert(token_ids, indices.clone())
-
-        if self.disable:
-            if del_in_memory_pool:
-                self.token_to_kv_pool.free(indices)
-            else:
-                return torch.tensor([], dtype=torch.int32), self.root_node
-
-        # Radix Cache takes one ref in memory pool
-        self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])
-
-        if del_in_memory_pool:
-            self.req_to_token_pool.free(req_pool_idx)
-        else:
-            cached_indices, new_last_node = self.match_prefix(token_ids)
-            assert len(cached_indices) == len(token_ids)
-
-            self.req_to_token_pool.req_to_token[
-                req_pool_idx, last_uncached_pos : len(cached_indices)
-            ] = cached_indices[last_uncached_pos:]
-            self.dec_lock_ref(old_last_node)
-            self.inc_lock_ref(new_last_node)
-            return cached_indices, new_last_node
+    def cache_finished_req(self, req: "Req", token_ids=None):
+        """Cache request when it finishes."""
+        if token_ids is None:
+            token_ids = (req.input_ids + req.output_ids)[:-1]
+
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+
+        if self.disable:
+            self.token_to_kv_pool.free(kv_indices)
+            self.req_to_token_pool.free(req.req_pool_idx)
+            return
+
+        # Radix Cache takes one ref in memory pool
+        new_prefix_len = self.insert(token_ids, kv_indices.clone())
+        self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
+
+        # Remove req slot release the cache lock
+        self.req_to_token_pool.free(req.req_pool_idx)
+        self.dec_lock_ref(req.last_node)
+
+    def cache_unfinished_req(self, req: "Req", token_ids=None):
+        """Cache request when it is unfinished."""
+        if self.disable:
+            return
+
+        if token_ids is None:
+            token_ids = req.input_ids
+
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+
+        # Radix Cache takes one ref in memory pool
+        new_prefix_len = self.insert(token_ids, kv_indices.clone())
+        self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
+
+        # The prefix indices could be updated, reuse it
+        new_indices, new_last_node = self.match_prefix(token_ids)
+        assert len(new_indices) == len(token_ids)
+        self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
+        ] = new_indices[len(req.prefix_indices) :]
+
+        self.dec_lock_ref(req.last_node)
+        self.inc_lock_ref(new_last_node)
+        req.prefix_indices = new_indices
+        req.last_node = new_last_node

     def pretty_print(self):
         self._print_helper(self.root_node, 0)
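Finally, a toy illustration of the lock-reference discipline the two new RadixCache methods preserve: `cache_unfinished_req` moves the request's lock from its old `last_node` to the node matching the longer prefix, while `cache_finished_req` simply drops it. `ToyNode` and `ToyRadixCache` below are illustrative stand-ins, not the real `TreeNode` / `RadixCache`.

class ToyNode:
    def __init__(self, name):
        self.name = name
        self.lock_ref = 0          # nodes with lock_ref > 0 are pinned against eviction


class ToyRadixCache:
    def inc_lock_ref(self, node):
        node.lock_ref += 1

    def dec_lock_ref(self, node):
        node.lock_ref -= 1


cache = ToyRadixCache()
old_node = ToyNode("short prefix")
new_node = ToyNode("longer prefix")

cache.inc_lock_ref(old_node)   # request pins the prefix it matched at schedule time
# ... another chunk is prefilled; cache_unfinished_req re-matches the prefix ...
cache.dec_lock_ref(old_node)   # release the old node
cache.inc_lock_ref(new_node)   # pin the longer prefix instead
# ... the request finishes; cache_finished_req releases the last lock ...
cache.dec_lock_ref(new_node)

print(old_node.lock_ref, new_node.lock_ref)   # 0 0 -> nothing left pinned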