sglang · Commit 7623091d

RadixCache method adjust (#977)

Unverified commit, authored Aug 07, 2024 by Liangsheng Yin; committed by GitHub on Aug 07, 2024.
Parent: f724f1f1

Showing 5 changed files with 141 additions and 119 deletions:
python/sglang/srt/managers/schedule_batch.py      +22 -13
python/sglang/srt/managers/tp_worker.py           +37 -60
python/sglang/srt/mem_cache/base_prefix_cache.py  +5  -1   (renamed from base_cache.py)
python/sglang/srt/mem_cache/chunk_cache.py        +29 -15
python/sglang/srt/mem_cache/radix_cache.py        +48 -30
python/sglang/srt/managers/schedule_batch.py
@@ -124,7 +124,7 @@ class Req:
         # For vision input
         self.pixel_values = None
         self.image_size = None
-        self.image_offset = 0
+        self.image_offset = None
         self.pad_value = None
 
         # Prefix info
@@ -162,6 +162,13 @@ class Req:
     def finished(self) -> bool:
         return self.finished_reason is not None
 
+    def adjust_max_prefix_ids(self):
+        max_prefix_ids = self.input_ids
+        if self.return_logprob:
+            max_prefix_ids = self.input_ids[: self.logprob_start_len]
+        return max_prefix_ids
+
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
     def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
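The new adjust_max_prefix_ids helper caps the key used for radix-tree prefix matching: when logprobs are requested, tokens from logprob_start_len onward must be recomputed in the extend pass so their logprobs exist, and therefore must not be served from the cache. A minimal runnable sketch of that behavior (the _ReqStub class and the concrete token values are illustrative, not part of this commit):

# Illustrative stand-in for Req, reproducing only the fields the helper reads.
class _ReqStub:
    def __init__(self, input_ids, return_logprob, logprob_start_len):
        self.input_ids = input_ids
        self.return_logprob = return_logprob
        self.logprob_start_len = logprob_start_len

    def adjust_max_prefix_ids(self):
        # Same logic as the method added in the hunk above.
        max_prefix_ids = self.input_ids
        if self.return_logprob:
            max_prefix_ids = self.input_ids[: self.logprob_start_len]
        return max_prefix_ids

req = _ReqStub(input_ids=[5, 8, 13, 21, 34], return_logprob=True, logprob_start_len=2)
# Only the first two tokens may be matched against the cache.
assert req.adjust_max_prefix_ids() == [5, 8]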
@@ -444,7 +451,8 @@ class ScheduleBatch:
         self.pixel_values = [r.pixel_values for r in reqs]
         self.image_sizes = [r.image_size for r in reqs]
         self.image_offsets = [
-            r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
+            (r.image_offset - p_len) if r.image_offset is not None else 0
+            for r, p_len in zip(reqs, prefix_lens)
         ]
         self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
         self.extend_num_tokens = extend_num_tokens
@@ -596,15 +604,7 @@ class ScheduleBatch:
                 req.vid += 1
 
-                # insert the old request into tree_cache
-                self.tree_cache.cache_req(
-                    rid=req.rid,
-                    token_ids=cur_all_ids,
-                    last_uncached_pos=len(req.prefix_indices),
-                    req_pool_idx=req.req_pool_idx,
-                )
-
-                # unlock the last node
-                self.tree_cache.dec_lock_ref(req.last_node)
+                self.tree_cache.cache_finished_req(req, cur_all_ids)
 
                 # re-applying image padding
                 if req.pixel_values is not None:
@@ -621,7 +621,6 @@ class ScheduleBatch:
             jump_forward_reqs.append(req)
             filter_indices.remove(i)
 
-        if len(filter_indices) < len(self.reqs):
-            self.filter_batch(filter_indices)
+        self.filter_batch(filter_indices)
 
         return jump_forward_reqs
@@ -644,6 +643,15 @@ class ScheduleBatch:
         ] = self.out_cache_loc
 
     def filter_batch(self, unfinished_indices: List[int]):
+        if unfinished_indices is None or len(unfinished_indices) == 0:
+            # Filter out all requests
+            self.reqs = []
+            return
+
+        if len(unfinished_indices) == len(self.reqs):
+            # No need to filter
+            return
+
         self.reqs = [self.reqs[i] for i in unfinished_indices]
         new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
         self.seq_lens = self.seq_lens[new_indices]
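The guard clauses added to filter_batch make its contract explicit: an empty or None index list clears the batch, and a full index list is a no-op that skips the tensor re-indexing below it. A self-contained model of just that selection logic (filter_reqs is a hypothetical stand-in, not code from the commit):

def filter_reqs(reqs, unfinished_indices):
    # None/empty: filter out all requests.
    if unfinished_indices is None or len(unfinished_indices) == 0:
        return []
    # Full list: no need to filter (avoids rebuilding the batch tensors).
    if len(unfinished_indices) == len(reqs):
        return reqs
    return [reqs[i] for i in unfinished_indices]

assert filter_reqs(["a", "b", "c"], None) == []
assert filter_reqs(["a", "b", "c"], [0, 1, 2]) == ["a", "b", "c"]
assert filter_reqs(["a", "b", "c"], [2]) == ["c"]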
@@ -711,6 +719,7 @@ class ScheduleBatch:
             self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
 
     def sample(self, logits: torch.Tensor):
+        # TODO(lsyin): move this into a part of layer and run with CUDA Graph
         # Post process logits
         logits = logits.contiguous()
         logits.div_(self.temperatures)
python/sglang/srt/managers/tp_worker.py
@@ -232,8 +232,6 @@ class ModelTpServer:
             if new_batch is not None:
                 # Run a new prefill batch
                 self.forward_prefill_batch(new_batch)
-                self.cache_filled_batch(new_batch)
-                self.filter_out_inflight(new_batch)
 
                 if not new_batch.is_empty():
                     if self.running_batch is None:
@@ -353,26 +351,20 @@ class ModelTpServer:
             self.waiting_queue.append(req)
 
     def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
+        # TODO(lsyin): organize this function
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
         if running_bs >= self.max_running_requests:
-            return
+            return None
 
         # Compute matched prefix length
         for req in self.waiting_queue:
             req.input_ids = req.origin_input_ids + req.output_ids
-            try_match_ids = req.input_ids
-            if req.return_logprob:
-                try_match_ids = req.input_ids[: req.logprob_start_len]
-            # NOTE: the prefix_indices must always be aligned with last_node
-            prefix_indices, last_node = self.tree_cache.match_prefix(
-                rid=req.rid, key=try_match_ids
+            req.prefix_indices, req.last_node = self.tree_cache.match_prefix(
+                rid=req.rid, key=req.adjust_max_prefix_ids()
             )
-            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
-            req.prefix_indices = prefix_indices
-            req.last_node = last_node
+            req.extend_input_len = len(req.input_ids) - len(req.prefix_indices)
 
         # Get priority queue
         self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
@@ -394,6 +386,24 @@ class ModelTpServer:
             )
 
         for req in self.waiting_queue:
+            # FIXME: Move this code into adjust_max_prefix_len
+            if req.return_logprob and req.normalized_prompt_logprob is None:
+                # Need at least two tokens to compute normalized logprob
+                if req.extend_input_len < 2:
+                    delta = 2 - req.extend_input_len
+                    req.extend_input_len += delta
+                    req.prefix_indices = req.prefix_indices[:-delta]
+                    if req.image_offset is not None:
+                        req.image_offset += delta
+
+            if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
+                # Need at least one token to compute logits
+                req.extend_input_len = 1
+                req.prefix_indices = req.prefix_indices[:-1]
+                if req.image_offset is not None:
+                    req.image_offset += 1
+
             res = adder.add_one_req(req)
             if (
                 not res
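The block above takes tokens back from the cached prefix when too much of the prompt was matched: computing a normalized prompt logprob needs at least two extend tokens, and computing logits at all needs at least one. A worked example of the first branch (the concrete lengths are chosen for illustration only):

# A fully cached 10-token prompt would leave nothing to extend.
extend_input_len = 0
prefix_indices = list(range(10))  # stand-in for the cached KV indices

if extend_input_len < 2:
    delta = 2 - extend_input_len              # -> 2
    extend_input_len += delta                 # -> 2 tokens recomputed in extend
    prefix_indices = prefix_indices[:-delta]  # cached prefix shrinks 10 -> 8

assert (extend_input_len, len(prefix_indices)) == (2, 8)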
@@ -470,10 +480,20 @@ class ModelTpServer:
         pt = 0
         for i, req in enumerate(batch.reqs):
             if req is not self.current_inflight_req:
                 # Inflight reqs' prefill is not finished
                 req.completion_tokens_wo_jump_forward += 1
                 req.output_ids.append(next_token_ids[i])
                 req.check_finished()
 
+            if req.finished():
+                self.tree_cache.cache_finished_req(req)
+            else:
+                self.tree_cache.cache_unfinished_req(req)
+
+            if req is self.current_inflight_req:
+                # Inflight request would get a new req idx
+                self.req_to_token_pool.free(req.req_pool_idx)
+
             if req.return_logprob:
                 self.add_logprob_return_values(i, req, pt, next_token_ids, output)
             pt += req.extend_input_len
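This is the core of the refactor on the scheduler side: instead of one cache_req call parameterized by del_in_memory_pool, the worker now dispatches on request state after each prefill step. A hedged sketch of that dispatch as a standalone function (a hypothetical wrapper that mirrors the call sites above, not code from the commit):

def cache_after_prefill(tree_cache, req_to_token_pool, req, current_inflight_req):
    if req.finished():
        # Finished: the cache frees the req slot and drops its node lock itself.
        tree_cache.cache_finished_req(req)
    else:
        # Unfinished: keep the prefix locked and refresh req.prefix_indices.
        tree_cache.cache_unfinished_req(req)
    if req is current_inflight_req:
        # An inflight (chunked-prefill) request will get a new req slot later.
        req_to_token_pool.free(req.req_pool_idx)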
@@ -529,22 +549,6 @@ class ModelTpServer:
             )
             req.output_top_logprobs.append(output.output_top_logprobs[i])
 
-    def cache_filled_batch(self, batch: ScheduleBatch):
-        for i, req in enumerate(batch.reqs):
-            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                rid=req.rid,
-                token_ids=tuple(req.input_ids),
-                last_uncached_pos=len(req.prefix_indices),
-                req_pool_idx=req.req_pool_idx,
-                del_in_memory_pool=False,
-                old_last_node=req.last_node,
-            )
-            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
-
-            if req is self.current_inflight_req:
-                # inflight request would get a new req idx
-                self.req_to_token_pool.free(req.req_pool_idx)
-
     def forward_decode_batch(self, batch: ScheduleBatch):
         # Check if decode out of memory
         if not batch.check_decode_mem():
@@ -595,6 +599,9 @@ class ModelTpServer:
                 req.output_ids.append(next_token_id)
                 req.check_finished()
 
+            if req.finished():
+                self.tree_cache.cache_finished_req(req)
+
             if req.return_logprob:
                 req.output_token_logprobs.append(
                     (next_token_logprobs[i], next_token_id)
@@ -614,12 +621,9 @@ class ModelTpServer:
         output_spaces_between_special_tokens = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
-        finished_indices = []
         unfinished_indices = []
         for i, req in enumerate(batch.reqs):
-            if req.finished():
-                finished_indices.append(i)
-            else:
+            if not req.finished() and req is not self.current_inflight_req:
                 unfinished_indices.append(i)
 
             if req.finished() or (
@@ -683,34 +687,7 @@ class ModelTpServer:
                         )
                     )
 
-        # Remove finished reqs
-        if finished_indices:
-            # Update radix cache
-            for i in finished_indices:
-                req = batch.reqs[i]
-                self.tree_cache.cache_req(
-                    rid=req.rid,
-                    token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
-                    last_uncached_pos=len(req.prefix_indices),
-                    req_pool_idx=req.req_pool_idx,
-                )
-
-                self.tree_cache.dec_lock_ref(req.last_node)
-
-        # Update batch tensors
-        if unfinished_indices:
-            batch.filter_batch(unfinished_indices)
-        else:
-            batch.reqs = []
-
-    def filter_out_inflight(self, batch: ScheduleBatch):
-        # TODO(lsyin): reduce the overhead, make a special version for this
-        if self.current_inflight_req is None:
-            return
-
-        to_remove = batch.reqs.index(self.current_inflight_req)
-        unfinished_indices = [i for i in range(len(batch.reqs)) if i != to_remove]
-
-        batch.filter_batch(unfinished_indices)
+        # Remove finished reqs: update batch tensors
+        batch.filter_batch(unfinished_indices)
 
     def flush_cache(self):
python/sglang/srt/mem_cache/base_cache.py → python/sglang/srt/mem_cache/base_prefix_cache.py (renamed)
@@ -17,7 +17,11 @@ class BasePrefixCache(ABC):
         pass
 
     @abstractmethod
-    def cache_req(self, **kwargs):
+    def cache_finished_req(self, **kwargs):
+        pass
+
+    @abstractmethod
+    def cache_unfinished_req(self, **kwargs):
         pass
 
     @abstractmethod
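After this hunk, the abstract base class declares the split pair instead of the single cache_req. For reference, the interface as the diff leaves it (only the methods visible in this hunk are shown; the class's other abstract methods are elided):

from abc import ABC, abstractmethod

class BasePrefixCache(ABC):
    @abstractmethod
    def cache_finished_req(self, **kwargs):
        pass

    @abstractmethod
    def cache_unfinished_req(self, **kwargs):
        pass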
python/sglang/srt/mem_cache/chunk_cache.py
"""Cache for chunked prefill, used when RadixCache is disabled."""
from
sglang.srt.mem_cache.base_cache
import
BasePrefixCache
from
typing
import
TYPE_CHECKING
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
if
TYPE_CHECKING
:
from
sglang.srt.managers.schedule_batch
import
Req
class
ChunkCacheEntry
:
...
...
@@ -27,22 +32,31 @@ class ChunkCache(BasePrefixCache):
         entry = self.entries[rid]
         return entry.value, entry
 
-    def cache_req(self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs):
-        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
-        if del_in_memory_pool:
-            assert rid in self.entries
-            self.req_to_token_pool.free(req_pool_idx)
-            self.token_to_kv_pool.free(indices)
-            return
+    def cache_finished_req(self, req: "Req", token_ids=None):
+        if token_ids is None:
+            token_ids = (req.input_ids + req.output_ids)[:-1]
 
-        if rid not in self.entries:
-            self.entries[rid] = ChunkCacheEntry(rid, indices)
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+        assert req.rid in self.entries
+        self.req_to_token_pool.free(req.req_pool_idx)
+        self.token_to_kv_pool.free(kv_indices)
 
-        entry = self.entries[rid]
-        entry.value = indices
-        return indices, entry
+    def cache_unfinished_req(self, req: "Req", token_ids=None):
+        if token_ids is None:
+            token_ids = req.input_ids
+
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+
+        if req.rid not in self.entries:
+            self.entries[req.rid] = ChunkCacheEntry(req.rid, kv_indices)
+
+        entry = self.entries[req.rid]
+        entry.value = kv_indices
+        return kv_indices, entry
 
     def insert(self):
         raise NotImplementedError
python/sglang/srt/mem_cache/radix_cache.py
@@ -20,10 +20,14 @@ The radix tree data structure for managing the KV cache.
 
 import heapq
 import time
 from collections import defaultdict
+from typing import TYPE_CHECKING
 
 import torch
 
-from sglang.srt.mem_cache.base_cache import BasePrefixCache
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req
 
 
 class TreeNode:
@@ -85,40 +89,54 @@ class RadixCache(BasePrefixCache):
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)
 
-    def cache_req(
-        self,
-        token_ids,
-        last_uncached_pos,
-        req_pool_idx,
-        del_in_memory_pool=True,
-        old_last_node=None,
-        **kwargs,
-    ):
-        # Insert the request into radix cache
-        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
-        new_prefix_len = self.insert(token_ids, indices.clone())
-
-        if self.disable:
-            if del_in_memory_pool:
-                self.token_to_kv_pool.free(indices)
-            else:
-                return torch.tensor([], dtype=torch.int32), self.root_node
-
-        # Radix Cache takes one ref in memory pool
-        self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])
-
-        if del_in_memory_pool:
-            self.req_to_token_pool.free(req_pool_idx)
-        else:
-            cached_indices, new_last_node = self.match_prefix(token_ids)
-            assert len(cached_indices) == len(token_ids)
-            self.req_to_token_pool.req_to_token[
-                req_pool_idx, last_uncached_pos : len(cached_indices)
-            ] = cached_indices[last_uncached_pos:]
-            self.dec_lock_ref(old_last_node)
-            self.inc_lock_ref(new_last_node)
-            return cached_indices, new_last_node
+    def cache_finished_req(self, req: "Req", token_ids=None):
+        """Cache request when it finishes."""
+        if token_ids is None:
+            token_ids = (req.input_ids + req.output_ids)[:-1]
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+
+        if self.disable:
+            self.token_to_kv_pool.free(kv_indices)
+            self.req_to_token_pool.free(req.req_pool_idx)
+            return
+
+        # Radix Cache takes one ref in memory pool
+        new_prefix_len = self.insert(token_ids, kv_indices.clone())
+        self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
+
+        # Remove req slot release the cache lock
+        self.req_to_token_pool.free(req.req_pool_idx)
+        self.dec_lock_ref(req.last_node)
+
+    def cache_unfinished_req(self, req: "Req", token_ids=None):
+        """Cache request when it is unfinished."""
+        if self.disable:
+            return
+
+        if token_ids is None:
+            token_ids = req.input_ids
+
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+
+        # Radix Cache takes one ref in memory pool
+        new_prefix_len = self.insert(token_ids, kv_indices.clone())
+        self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
+
+        # The prefix indices could be updated, reuse it
+        new_indices, new_last_node = self.match_prefix(token_ids)
+        assert len(new_indices) == len(token_ids)
+        self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
+        ] = new_indices[len(req.prefix_indices) :]
+
+        self.dec_lock_ref(req.last_node)
+        self.inc_lock_ref(new_last_node)
+        req.prefix_indices = new_indices
+        req.last_node = new_last_node
 
     def pretty_print(self):
         self._print_helper(self.root_node, 0)
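The subtle part of cache_unfinished_req is the reference hand-off. insert may deduplicate part of kv_indices against existing tree nodes, so the overlapping slice kv_indices[len(req.prefix_indices):new_prefix_len] is returned to token_to_kv_pool (the tree now holds that reference); the request is then re-matched to pick up the canonical indices, and the eviction lock moves from the old last_node to the new one. A compact restatement of that lock migration (an illustrative helper, not part of the commit):

def migrate_lock(cache, req, new_indices, new_last_node):
    # Release the node pinned when the request was first scheduled...
    cache.dec_lock_ref(req.last_node)
    # ...and pin the (possibly deeper) node found by the fresh match_prefix,
    # so eviction cannot reclaim this prefix while the request still decodes.
    cache.inc_lock_ref(new_last_node)
    req.prefix_indices = new_indices
    req.last_node = new_last_node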