sglang · Commit 7623091d (Unverified)
Authored Aug 07, 2024 by Liangsheng Yin; committed by GitHub on Aug 07, 2024
RadixCache method adjust (#977)
parent f724f1f1

Showing 5 changed files with 141 additions and 119 deletions.
python/sglang/srt/managers/schedule_batch.py       +22 -13
python/sglang/srt/managers/tp_worker.py            +37 -60
python/sglang/srt/mem_cache/base_prefix_cache.py    +5  -1
python/sglang/srt/mem_cache/chunk_cache.py         +29 -15
python/sglang/srt/mem_cache/radix_cache.py         +48 -30
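Taken together, the change replaces the single kwargs-based cache_req entry point on BasePrefixCache with two explicit methods, cache_finished_req and cache_unfinished_req, both taking the Req object directly. A minimal sketch of a caller under the new interface (hypothetical decode_step helper, not code from this commit; tree_cache and req stand in for the real RadixCache/ChunkCache and Req objects):

    # Hypothetical helper illustrating the split API introduced here.
    # `tree_cache` follows BasePrefixCache; `req` mimics Req from
    # sglang.srt.managers.schedule_batch (output_ids, check_finished, finished).
    def decode_step(tree_cache, req, next_token_id):
        req.output_ids.append(next_token_id)
        req.check_finished()
        if req.finished():
            # One call now frees the req slot and releases the tree lock.
            tree_cache.cache_finished_req(req)
        else:
            # Keeps the request cached and its last_node locked for the next step.
            tree_cache.cache_unfinished_req(req)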
python/sglang/srt/managers/schedule_batch.py
@@ -124,7 +124,7 @@ class Req:
         # For vision input
         self.pixel_values = None
         self.image_size = None
-        self.image_offset = 0
+        self.image_offset = None
         self.pad_value = None

         # Prefix info
@@ -162,6 +162,13 @@ class Req:
     def finished(self) -> bool:
         return self.finished_reason is not None

+    def adjust_max_prefix_ids(self):
+        max_prefix_ids = self.input_ids
+        if self.return_logprob:
+            max_prefix_ids = self.input_ids[: self.logprob_start_len]
+        return max_prefix_ids
+
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
     def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
@@ -444,7 +451,8 @@ class ScheduleBatch:
         self.pixel_values = [r.pixel_values for r in reqs]
         self.image_sizes = [r.image_size for r in reqs]
-        self.image_offsets = [
-            r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
-        ]
+        self.image_offsets = [
+            (r.image_offset - p_len) if r.image_offset is not None else 0
+            for r, p_len in zip(reqs, prefix_lens)
+        ]
         self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
         self.extend_num_tokens = extend_num_tokens
@@ -596,15 +604,7 @@ class ScheduleBatch:
                     req.vid += 1

-                    # insert the old request into tree_cache
-                    self.tree_cache.cache_req(
-                        rid=req.rid,
-                        token_ids=cur_all_ids,
-                        last_uncached_pos=len(req.prefix_indices),
-                        req_pool_idx=req.req_pool_idx,
-                    )
-
-                    # unlock the last node
-                    self.tree_cache.dec_lock_ref(req.last_node)
+                    self.tree_cache.cache_finished_req(req, cur_all_ids)

                     # re-applying image padding
                     if req.pixel_values is not None:
@@ -621,8 +621,7 @@ class ScheduleBatch:
                     jump_forward_reqs.append(req)
                     filter_indices.remove(i)

-        if len(filter_indices) < len(self.reqs):
-            self.filter_batch(filter_indices)
+        self.filter_batch(filter_indices)

         return jump_forward_reqs
@@ -644,6 +643,15 @@ class ScheduleBatch:
         ] = self.out_cache_loc

     def filter_batch(self, unfinished_indices: List[int]):
+        if unfinished_indices is None or len(unfinished_indices) == 0:
+            # Filter out all requests
+            self.reqs = []
+            return
+
+        if len(unfinished_indices) == len(self.reqs):
+            # No need to filter
+            return
+
         self.reqs = [self.reqs[i] for i in unfinished_indices]
         new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
         self.seq_lens = self.seq_lens[new_indices]
@@ -711,6 +719,7 @@ class ScheduleBatch:
         self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

     def sample(self, logits: torch.Tensor):
+        # TODO(lsyin): move this into a part of layer and run with CUDA Graph
         # Post process logits
         logits = logits.contiguous()
         logits.div_(self.temperatures)
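The new filter_batch early returns can be exercised with plain lists. A standalone sketch of the three paths (illustration only; the real method also reindexes the batch tensors such as seq_lens):

    from typing import List, Optional

    def filter_reqs(reqs: list, unfinished_indices: Optional[List[int]]) -> list:
        # Standalone mirror of the new ScheduleBatch.filter_batch guards.
        if unfinished_indices is None or len(unfinished_indices) == 0:
            # Filter out all requests
            return []
        if len(unfinished_indices) == len(reqs):
            # No need to filter
            return reqs
        return [reqs[i] for i in unfinished_indices]

    assert filter_reqs(["a", "b", "c"], None) == []
    assert filter_reqs(["a", "b", "c"], [0, 1, 2]) == ["a", "b", "c"]
    assert filter_reqs(["a", "b", "c"], [0, 2]) == ["a", "c"]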
python/sglang/srt/managers/tp_worker.py
@@ -232,8 +232,6 @@ class ModelTpServer:
             if new_batch is not None:
                 # Run a new prefill batch
                 self.forward_prefill_batch(new_batch)
-                self.cache_filled_batch(new_batch)
-                self.filter_out_inflight(new_batch)

                 if not new_batch.is_empty():
                     if self.running_batch is None:
@@ -353,26 +351,20 @@ class ModelTpServer:
             self.waiting_queue.append(req)

     def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
-        # TODO(lsyin): organize this function
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
         if running_bs >= self.max_running_requests:
-            return
+            return None

         # Compute matched prefix length
         for req in self.waiting_queue:
             req.input_ids = req.origin_input_ids + req.output_ids
-            try_match_ids = req.input_ids
-            if req.return_logprob:
-                try_match_ids = req.input_ids[: req.logprob_start_len]
             # NOTE: the prefix_indices must always be aligned with last_node
-            prefix_indices, last_node = self.tree_cache.match_prefix(
-                rid=req.rid, key=try_match_ids
-            )
-            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
-            req.prefix_indices = prefix_indices
-            req.last_node = last_node
+            req.prefix_indices, req.last_node = self.tree_cache.match_prefix(
+                rid=req.rid, key=req.adjust_max_prefix_ids()
+            )
+            req.extend_input_len = len(req.input_ids) - len(req.prefix_indices)

         # Get priority queue
         self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
@@ -394,6 +386,24 @@ class ModelTpServer:
         )

         for req in self.waiting_queue:
+            # FIXME: Move this code into adjust_max_prefix_len
+            if req.return_logprob and req.normalized_prompt_logprob is None:
+                # Need at least two tokens to compute normalized logprob
+                if req.extend_input_len < 2:
+                    delta = 2 - req.extend_input_len
+                    req.extend_input_len += delta
+                    req.prefix_indices = req.prefix_indices[:-delta]
+                    if req.image_offset is not None:
+                        req.image_offset += delta
+
+            if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
+                # Need at least one token to compute logits
+                req.extend_input_len = 1
+                req.prefix_indices = req.prefix_indices[:-1]
+                if req.image_offset is not None:
+                    req.image_offset += 1
+
             res = adder.add_one_req(req)
             if (
                 not res
@@ -470,10 +480,20 @@ class ModelTpServer:
         pt = 0
         for i, req in enumerate(batch.reqs):
             if req is not self.current_inflight_req:
                 # Inflight reqs' prefill is not finished
                 req.completion_tokens_wo_jump_forward += 1
                 req.output_ids.append(next_token_ids[i])
                 req.check_finished()

+            if req.finished():
+                self.tree_cache.cache_finished_req(req)
+            else:
+                self.tree_cache.cache_unfinished_req(req)
+
+            if req is self.current_inflight_req:
+                # Inflight request would get a new req idx
+                self.req_to_token_pool.free(req.req_pool_idx)
+
             if req.return_logprob:
                 self.add_logprob_return_values(i, req, pt, next_token_ids, output)
             pt += req.extend_input_len
@@ -529,22 +549,6 @@ class ModelTpServer:
                 )
             req.output_top_logprobs.append(output.output_top_logprobs[i])

-    def cache_filled_batch(self, batch: ScheduleBatch):
-        for i, req in enumerate(batch.reqs):
-            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                rid=req.rid,
-                token_ids=tuple(req.input_ids),
-                last_uncached_pos=len(req.prefix_indices),
-                req_pool_idx=req.req_pool_idx,
-                del_in_memory_pool=False,
-                old_last_node=req.last_node,
-            )
-            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
-
-            if req is self.current_inflight_req:
-                # inflight request would get a new req idx
-                self.req_to_token_pool.free(req.req_pool_idx)
-
     def forward_decode_batch(self, batch: ScheduleBatch):
         # Check if decode out of memory
         if not batch.check_decode_mem():
@@ -595,6 +599,9 @@ class ModelTpServer:
             req.output_ids.append(next_token_id)
             req.check_finished()

+            if req.finished():
+                self.tree_cache.cache_finished_req(req)
+
             if req.return_logprob:
                 req.output_token_logprobs.append(
                     (next_token_logprobs[i], next_token_id)
@@ -614,12 +621,9 @@ class ModelTpServer:
         output_spaces_between_special_tokens = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
-        finished_indices = []
         unfinished_indices = []

         for i, req in enumerate(batch.reqs):
-            if req.finished():
-                finished_indices.append(i)
-            else:
+            if not req.finished() and req is not self.current_inflight_req:
                 unfinished_indices.append(i)

             if req.finished() or (
@@ -683,34 +687,7 @@ class ModelTpServer:
                     )
                 )

-        # Remove finished reqs
-        if finished_indices:
-            # Update radix cache
-            for i in finished_indices:
-                req = batch.reqs[i]
-                self.tree_cache.cache_req(
-                    rid=req.rid,
-                    token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
-                    last_uncached_pos=len(req.prefix_indices),
-                    req_pool_idx=req.req_pool_idx,
-                )
-                self.tree_cache.dec_lock_ref(req.last_node)
-
-        # Update batch tensors
-        if unfinished_indices:
-            batch.filter_batch(unfinished_indices)
-        else:
-            batch.reqs = []
-
-    def filter_out_inflight(self, batch: ScheduleBatch):
-        # TODO(lsyin): reduce the overhead, make a special version for this
-        if self.current_inflight_req is None:
-            return
-
-        to_remove = batch.reqs.index(self.current_inflight_req)
-        unfinished_indices = [i for i in range(len(batch.reqs)) if i != to_remove]
-        batch.filter_batch(unfinished_indices)
+        # Remove finished reqs: update batch tensors
+        batch.filter_batch(unfinished_indices)

     def flush_cache(self):
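In get_new_prefill_batch above, the prefix-match key now comes from Req.adjust_max_prefix_ids. Its behavior can be shown with plain lists (standalone sketch, not the Req method itself):

    def adjust_max_prefix_ids(input_ids, return_logprob, logprob_start_len):
        # Tokens at positions >= logprob_start_len must be recomputed to get
        # their logprobs, so they are excluded from the prefix-match key.
        if return_logprob:
            return input_ids[:logprob_start_len]
        return input_ids

    # With logprobs enabled, the cacheable prefix is capped at logprob_start_len.
    assert adjust_max_prefix_ids([1, 2, 3, 4, 5], True, 2) == [1, 2]
    assert adjust_max_prefix_ids([1, 2, 3, 4, 5], False, 2) == [1, 2, 3, 4, 5]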
python/sglang/srt/mem_cache/base_cache.py → python/sglang/srt/mem_cache/base_prefix_cache.py
@@ -17,7 +17,11 @@ class BasePrefixCache(ABC):
         pass

     @abstractmethod
-    def cache_req(self, **kwargs):
+    def cache_finished_req(self, **kwargs):
+        pass
+
+    @abstractmethod
+    def cache_unfinished_req(self, **kwargs):
         pass

     @abstractmethod
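Subclasses must now implement both halves of the old cache_req. A self-contained sketch of the enforcement (stand-in ABC for illustration; the real BasePrefixCache declares additional abstract methods that a concrete cache must also provide):

    from abc import ABC, abstractmethod

    # Stand-in for the updated interface, reduced to the two methods
    # touched by this commit.
    class BasePrefixCacheSketch(ABC):
        @abstractmethod
        def cache_finished_req(self, **kwargs): ...

        @abstractmethod
        def cache_unfinished_req(self, **kwargs): ...

    class NoopCache(BasePrefixCacheSketch):
        def cache_finished_req(self, **kwargs):
            pass  # nothing to retain once a request finishes

        def cache_unfinished_req(self, **kwargs):
            pass  # nothing to retain between decode steps

    NoopCache()  # instantiable: both new abstract methods are implemented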
python/sglang/srt/mem_cache/chunk_cache.py
"""Cache for chunked prefill, used when RadixCache is disabled."""
from
sglang.srt.mem_cache.base_cache
import
BasePrefixCache
from
typing
import
TYPE_CHECKING
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
if
TYPE_CHECKING
:
from
sglang.srt.managers.schedule_batch
import
Req
class
ChunkCacheEntry
:
...
...
@@ -27,22 +32,31 @@ class ChunkCache(BasePrefixCache):
         entry = self.entries[rid]
         return entry.value, entry

-    def cache_req(self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs):
-        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
-        if del_in_memory_pool:
-            assert rid in self.entries
-            self.req_to_token_pool.free(req_pool_idx)
-            self.token_to_kv_pool.free(indices)
-            return
-
-        if rid not in self.entries:
-            self.entries[rid] = ChunkCacheEntry(rid, indices)
-
-        entry = self.entries[rid]
-        entry.value = indices
-        return indices, entry
+    def cache_finished_req(self, req: "Req", token_ids=None):
+        if token_ids is None:
+            token_ids = (req.input_ids + req.output_ids)[:-1]
+
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+        assert req.rid in self.entries
+        self.req_to_token_pool.free(req.req_pool_idx)
+        self.token_to_kv_pool.free(kv_indices)
+
+    def cache_unfinished_req(self, req: "Req", token_ids=None):
+        if token_ids is None:
+            token_ids = req.input_ids
+
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+
+        if req.rid not in self.entries:
+            self.entries[req.rid] = ChunkCacheEntry(req.rid, kv_indices)
+
+        entry = self.entries[req.rid]
+        entry.value = kv_indices
+        return kv_indices, entry

     def insert(self):
         raise NotImplementedError
python/sglang/srt/mem_cache/radix_cache.py
@@ -20,10 +20,14 @@ The radix tree data structure for managing the KV cache.
 import heapq
 import time
 from collections import defaultdict
+from typing import TYPE_CHECKING

 import torch

-from sglang.srt.mem_cache.base_cache import BasePrefixCache
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req


 class TreeNode:
@@ -85,40 +89,54 @@ class RadixCache(BasePrefixCache):
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)

-    def cache_req(
-        self,
-        token_ids,
-        last_uncached_pos,
-        req_pool_idx,
-        del_in_memory_pool=True,
-        old_last_node=None,
-        **kwargs,
-    ):
-        # Insert the request into radix cache
-        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
-        new_prefix_len = self.insert(token_ids, indices.clone())
-
-        if self.disable:
-            if del_in_memory_pool:
-                self.token_to_kv_pool.free(indices)
-            else:
-                return torch.tensor([], dtype=torch.int32), self.root_node
-
-        # Radix Cache takes one ref in memory pool
-        self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])
-
-        if del_in_memory_pool:
-            self.req_to_token_pool.free(req_pool_idx)
-        else:
-            cached_indices, new_last_node = self.match_prefix(token_ids)
-            assert len(cached_indices) == len(token_ids)
-            self.req_to_token_pool.req_to_token[
-                req_pool_idx, last_uncached_pos : len(cached_indices)
-            ] = cached_indices[last_uncached_pos:]
-            self.dec_lock_ref(old_last_node)
-            self.inc_lock_ref(new_last_node)
-            return cached_indices, new_last_node
+    def cache_finished_req(self, req: "Req", token_ids=None):
+        """Cache request when it finishes."""
+        if token_ids is None:
+            token_ids = (req.input_ids + req.output_ids)[:-1]
+
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+
+        if self.disable:
+            self.token_to_kv_pool.free(kv_indices)
+            self.req_to_token_pool.free(req.req_pool_idx)
+            return
+
+        # Radix Cache takes one ref in memory pool
+        new_prefix_len = self.insert(token_ids, kv_indices.clone())
+        self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
+
+        # Remove req slot release the cache lock
+        self.req_to_token_pool.free(req.req_pool_idx)
+        self.dec_lock_ref(req.last_node)
+
+    def cache_unfinished_req(self, req: "Req", token_ids=None):
+        """Cache request when it is unfinished."""
+        if self.disable:
+            return
+
+        if token_ids is None:
+            token_ids = req.input_ids
+
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+
+        # Radix Cache takes one ref in memory pool
+        new_prefix_len = self.insert(token_ids, kv_indices.clone())
+        self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
+
+        # The prefix indices could be updated, reuse it
+        new_indices, new_last_node = self.match_prefix(token_ids)
+        assert len(new_indices) == len(token_ids)
+        self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
+        ] = new_indices[len(req.prefix_indices) :]
+
+        self.dec_lock_ref(req.last_node)
+        self.inc_lock_ref(new_last_node)
+        req.prefix_indices = new_indices
+        req.last_node = new_last_node

     def pretty_print(self):
         self._print_helper(self.root_node, 0)
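The slice freed in both new methods follows from the radix tree taking its own reference: insert() returns new_prefix_len, the length of token_ids now covered by the tree, so the KV slots between the previously matched prefix and that point are double-referenced and one reference is released. A toy walk-through of the arithmetic (plain integers, hypothetical sizes):

    # Suppose a request had 4 tokens already matched in the tree and now
    # caches 10 tokens, of which insert() reports 7 are covered by the tree.
    prefix_len = 4        # len(req.prefix_indices) before caching
    num_tokens = 10       # len(token_ids) being cached
    new_prefix_len = 7    # returned by self.insert(token_ids, kv_indices.clone())

    # kv_indices[prefix_len:new_prefix_len] -> slots 4, 5, 6 are now also
    # owned by the tree, so the request's reference to them is freed.
    freed = list(range(num_tokens))[prefix_len:new_prefix_len]
    assert freed == [4, 5, 6]

    # Slots 7, 8, 9 stay uniquely owned: an unfinished request re-points them
    # at the tree's indices via match_prefix; a finished request instead has
    # its whole req slot released and the lock on last_node dropped.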