sglang · Commit 9208618b (unverified)

[Core] in batch prefix caching by delay scheduling (#2442)

Authored Dec 11, 2024 by SangBin Cho; committed by GitHub on Dec 11, 2024.
Parent: 864bf2ba

Showing 8 changed files with 87 additions and 16 deletions (+87 -16).
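The commit's idea, per the scheduling comments further down: when several queued requests share a long prefix that is not yet in the radix cache, scheduling them in the same batch wastes prefill compute, since none of them can reuse the others' KV cache. Delaying all but one lets the first request populate the cache, so the rest hit it on the next round. A toy illustration of the effect (hypothetical names, not the commit's code):

```python
# Toy model of in-batch prefix caching by delay scheduling.
# prefill_cost() charges every token that is not already cached; requests
# scheduled in the same step cannot reuse each other's KV cache, so the
# cache is only updated after the whole step finishes.

def prefill_cost(step, cache):
    cost, new_entries = 0, set()
    for tokens in step:
        hit = 0
        while hit < len(tokens) and tuple(tokens[: hit + 1]) in cache:
            hit += 1
        cost += len(tokens) - hit
        for i in range(1, len(tokens) + 1):
            new_entries.add(tuple(tokens[:i]))
    cache |= new_entries
    return cost

shared = list(range(1000))                 # long shared prompt prefix
reqs = [shared + [i] for i in range(4)]    # four requests, unique last token

# Naive: all four in one step -> the shared prefix is prefilled four times.
print(prefill_cost(reqs, set()))                                      # 4004

# Delay scheduling: one request first, the other three the next round.
cache = set()
print(prefill_cost(reqs[:1], cache) + prefill_cost(reqs[1:], cache))  # 1004
```

The real implementation detects this shared-prefix situation with a temporary RadixCache during scheduling, as the schedule_policy.py diff below shows.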
Files changed:
  python/sglang/lang/backend/runtime_endpoint.py    +1  -0
  python/sglang/srt/managers/schedule_batch.py      +2  -0
  python/sglang/srt/managers/schedule_policy.py     +58 -7
  python/sglang/srt/managers/scheduler.py           +1  -1
  python/sglang/srt/mem_cache/base_prefix_cache.py  +2  -2
  python/sglang/srt/mem_cache/chunk_cache.py        +2  -2
  python/sglang/srt/mem_cache/radix_cache.py        +12 -2
  python/sglang/utils.py                            +9  -2
python/sglang/lang/backend/runtime_endpoint.py

```diff
@@ -55,6 +55,7 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/flush_cache",
             api_key=self.api_key,
             verify=self.verify,
+            method="POST",
         )
         self._assert_success(res)
```
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -256,6 +256,7 @@ class Req:
         # Prefix info
         self.prefix_indices = []
+        # Tokens to run prefill. input_tokens - shared_prefix_tokens.
         self.extend_input_len = 0
         self.last_node = None
@@ -316,6 +317,7 @@ class Req:
     def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
+            # tree_cache is None if the prefix is not computed with the tree cache.
             self.prefix_indices, self.last_node = tree_cache.match_prefix(
                 rid=self.rid, key=self.adjust_max_prefix_ids()
             )
```
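As context for the two comments added above: `fill_ids` is everything the model must have processed for the next round, `prefix_indices` covers the part already held in the KV cache, and `extend_input_len` is the remainder that actually needs a prefill pass. A hedged sketch of the relationship, with plain lists standing in for the real fields:

```python
# Hypothetical values illustrating the Req fields from the diff above.
origin_input_ids = [101, 2023, 2003, 1037]  # prompt tokens
output_ids = [2742]                          # tokens generated so far

# As in init_next_round_input:
fill_ids = origin_input_ids + output_ids

# Suppose the tree cache matched the first three tokens.
prefix_indices = fill_ids[:3]

# "Tokens to run prefill. input_tokens - shared_prefix_tokens."
extend_input_len = len(fill_ids) - len(prefix_indices)
print(extend_input_len)  # 2 -> only two tokens go through prefill
```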
python/sglang/srt/managers/schedule_policy.py

```diff
@@ -20,9 +20,11 @@ from contextlib import contextmanager
 from enum import Enum, auto
 from typing import Dict, List, Optional
 
+import torch
+
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.radix_cache import TreeNode
+from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
 
 # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
 # This can prevent the server from being too conservative.
```
```diff
@@ -32,6 +34,13 @@ CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
     os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
 )
 
+# The threshold to apply in-batch prefix caching.
+# If the value is too small, in-batch prefix caching cannot be used; e.g.,
+# imagine a trivially short cached prefix such as "the".
+IN_BATCH_PREFIX_CACHING_THRESHOLD = int(
+    os.environ.get("SGLANG_IN_BATCH_PREFIX_CACHING_THRESHOLD", "32")
+)
+
 
 class SchedulePolicy:
     def __init__(self, policy: str, tree_cache: BasePrefixCache):
```
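Since both knobs are resolved once at module import, overriding the threshold through the environment has to happen before the module is imported. A minimal sketch:

```python
import os

# Must be set before sglang.srt.managers.schedule_policy is imported,
# because IN_BATCH_PREFIX_CACHING_THRESHOLD is computed at import time.
os.environ["SGLANG_IN_BATCH_PREFIX_CACHING_THRESHOLD"] = "64"

from sglang.srt.managers import schedule_policy

print(schedule_policy.IN_BATCH_PREFIX_CACHING_THRESHOLD)  # 64
```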
```diff
@@ -51,18 +60,50 @@ class SchedulePolicy:
         # Compute matched prefix length
         prefix_computed = False
+        # rid to deprioritize in the current run.
+        temporary_deprioritized = {}
         if policy == "lpm" or policy == "dfs-weight":
+            # It is used to find the matching prefix for in-batch prefix caching.
+            temp_radix = RadixCache(None, None, False)
             for r in waiting_queue:
+                prefix_ids = r.adjust_max_prefix_ids()
                 # NOTE: the prefix_indices must always be aligned with last_node
                 r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
-                    rid=r.rid, key=r.adjust_max_prefix_ids()
+                    rid=r.rid, key=prefix_ids
                 )
+
+                # NOTE(sang): This logic is for in-batch prefix caching.
+                # If more than one request has only a small matching prefix in the
+                # existing cache but those requests share the same prefix, we prefer
+                # to schedule only one of them so that we can increase the cache hit
+                # rate. We prefer to set IN_BATCH_PREFIX_CACHING_THRESHOLD > 0 because
+                # a too-small threshold means we cannot use in-batch prefix caching
+                # for short prefixes, which are common when the engine is long-running
+                # (e.g., imagine "the").
+                if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_THRESHOLD:
+                    in_batch_matching_prefixes, _ = temp_radix.match_prefix(
+                        rid=r.rid, key=prefix_ids
+                    )
+                    if (
+                        len(in_batch_matching_prefixes)
+                        >= IN_BATCH_PREFIX_CACHING_THRESHOLD
+                    ):
+                        temporary_deprioritized[r.rid] = r
+                    else:
+                        temp_radix.insert(prefix_ids, torch.tensor(prefix_ids))
+
             prefix_computed = True
 
         if policy == "lpm":
             # Longest Prefix Match
-            waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
+            def get_priority(r: Req):
+                score = 0
+                if r.rid in temporary_deprioritized:
+                    score = float("inf")
+                else:
+                    score = -len(r.prefix_indices)
+                return score
+
+            waiting_queue.sort(key=get_priority)
         elif policy == "fcfs":
             # first come first serve
             pass
```
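Stripped of the radix-tree machinery, the deprioritization decision above reduces to roughly the following sketch. A plain list of already-seen prefix sequences stands in for the temporary RadixCache, and `cached_len` abstracts the match against the persistent tree cache; all names here are hypothetical:

```python
THRESHOLD = 32  # plays the role of IN_BATCH_PREFIX_CACHING_THRESHOLD

def common_prefix_len(a, b):
    n = 0
    while n < len(a) and n < len(b) and a[n] == b[n]:
        n += 1
    return n

def split_waiting_queue(waiting_queue, cached_len):
    """Return (schedule_now, deprioritized) for one scheduling round."""
    seen, schedule_now, deprioritized = [], [], []
    for req in waiting_queue:
        ids = req["prefix_ids"]
        # Only requests with a small hit in the persistent cache are candidates.
        if cached_len(req) <= THRESHOLD:
            best = max((common_prefix_len(ids, s) for s in seen), default=0)
            if best >= THRESHOLD:
                # Another queued request carries the same (uncached) prefix:
                # delay this one so it hits the cache in a later round.
                deprioritized.append(req)
                continue
            seen.append(ids)
        schedule_now.append(req)
    return schedule_now, deprioritized

# Three requests sharing a 100-token prefix, nothing cached yet:
reqs = [{"rid": i, "prefix_ids": [7] * 100 + [i]} for i in range(3)]
now, later = split_waiting_queue(reqs, cached_len=lambda r: 0)
print(len(now), len(later))  # 1 2
```

Unlike this sketch, the real code does not drop the deprioritized requests: the `lpm` sort gives them an infinite score and `dfs-weight` re-appends them at the tail (see the next hunks), so they simply run after the cache has been primed.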
```diff
@@ -76,6 +117,7 @@ class SchedulePolicy:
             for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)
 
+            # node -> # of requests for that node.
             node_to_weight = defaultdict(int)
             for node in last_node_to_reqs:
                 node_to_weight[node] = len(last_node_to_reqs[node])
@@ -87,7 +129,9 @@ class SchedulePolicy:
                 node_to_weight,
                 last_node_to_reqs,
                 waiting_queue,
+                temporary_deprioritized,
             )
+            waiting_queue.extend(temporary_deprioritized.values())
         else:
             raise ValueError(f"Unknown schedule_policy: {policy=}")
@@ -101,15 +145,22 @@ class SchedulePolicy:
     def get_dfs_priority(
         self,
         cur_node: TreeNode,
-        node_to_priority: Dict,
-        last_node_to_reqs: Dict,
+        node_to_priority: Dict[TreeNode, int],
+        last_node_to_reqs: Dict[TreeNode, List[Req]],
         q: List,
+        temporary_deprioritized: Dict[str, Req],
     ):
         childs = [child for child in cur_node.children.values()]
         childs.sort(key=lambda x: -node_to_priority[x])
         for child in childs:
-            self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
-        q.extend(last_node_to_reqs[cur_node])
+            self.get_dfs_priority(
+                child, node_to_priority, last_node_to_reqs, q, temporary_deprioritized
+            )
+        for req in last_node_to_reqs[cur_node]:
+            if req.rid in temporary_deprioritized:
+                continue
+            q.append(req)
 
 
 class AddReqResult(Enum):
```
python/sglang/srt/managers/scheduler.py

```diff
@@ -713,7 +713,7 @@ class Scheduler:
             if crash_on_warnings():
                 raise ValueError(msg)
 
-    def get_next_batch_to_run(self):
+    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
         if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.being_chunked_req:
```
python/sglang/srt/mem_cache/base_prefix_cache.py

```diff
 from abc import ABC, abstractmethod
-from typing import Callable
+from typing import Callable, List, Tuple
 
 
 class BasePrefixCache(ABC):
@@ -10,7 +10,7 @@ class BasePrefixCache(ABC):
         pass
 
     @abstractmethod
-    def match_prefix(self, **kwargs):
+    def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
         pass
 
     @abstractmethod
```
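For orientation, the contract the new annotation documents can be shown with a self-contained stand-in. This is not a class from the repository, it deliberately skips the other abstract methods, and the second tuple element is whatever node handle the concrete cache uses, despite the `int` in the annotation:

```python
from typing import Dict, List, Tuple

class DictPrefixCache:
    """Hypothetical minimal cache honoring the match_prefix contract."""

    def __init__(self) -> None:
        self.entries: Dict[int, List[int]] = {}

    def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], object]:
        if rid not in self.entries:
            return [], None  # same miss shape as ChunkCache below
        cached = self.entries[rid]
        n = 0
        while n < len(cached) and n < len(key) and cached[n] == key[n]:
            n += 1
        # The request id stands in for the "last node" handle here.
        return key[:n], rid
```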
python/sglang/srt/mem_cache/chunk_cache.py

```diff
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 """Cache for chunked prefill, used when RadixCache is disabled."""
 
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
 
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -30,7 +30,7 @@ class ChunkCache(BasePrefixCache):
     def reset(self):
         self.entries = {}
 
-    def match_prefix(self, rid: int, key: List[int]):
+    def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
         if rid not in self.entries:
             return [], None
```
python/sglang/srt/mem_cache/radix_cache.py

```diff
@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
 import heapq
 import time
 from collections import defaultdict
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
 
 import torch
@@ -76,7 +76,17 @@ class RadixCache(BasePrefixCache):
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0
 
-    def match_prefix(self, key: List, **kwargs):
+    def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
+        """Find the matching prefix from the radix tree.
+        Args:
+            key: A list of token IDs to find a matching prefix.
+        Returns:
+            A tuple of a tensor of matching prefix token IDs and
+            the last node that contains the prefix values. Note that
+            this API can modify the internal state of the radix tree:
+            the last node may create a new child if the matched prefix
+            is shorter than the last node's value.
+        """
         if self.disable:
             return [], self.root_node
```
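A hedged usage sketch of the documented behavior, constructing the cache the same way schedule_policy.py builds its temporary tree above (no token pools, caching enabled); the token values are made up:

```python
import torch
from sglang.srt.mem_cache.radix_cache import RadixCache

tree = RadixCache(None, None, False)  # as in schedule_policy.py above

key = [1, 2, 3, 4, 5]
tree.insert(key, torch.tensor(key))   # same call shape as the scheduler's

# A query sharing only the first three tokens returns a length-3 match;
# per the new docstring, the lookup may split a node inside the tree.
values, last_node = tree.match_prefix(rid=0, key=[1, 2, 3, 9])
print(len(values))  # 3
```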
python/sglang/utils.py

```diff
@@ -79,7 +79,14 @@ class HttpResponse:
         return self.resp.status
 
 
-def http_request(url, json=None, stream=False, api_key=None, verify=None):
+def http_request(
+    url,
+    json=None,
+    stream=False,
+    api_key=None,
+    verify=None,
+    method: Optional[str] = None,
+):
     """A faster version of requests.post with low-level urllib API."""
     headers = {"Content-Type": "application/json; charset=utf-8"}
@@ -90,7 +97,7 @@ def http_request(url, json=None, stream=False, api_key=None, verify=None):
     if stream:
         return requests.post(url, json=json, stream=True, headers=headers)
     else:
-        req = urllib.request.Request(url, headers=headers)
+        req = urllib.request.Request(url, headers=headers, method=method)
         if json is None:
             data = None
         else:
```
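The plumbing relies on the standard library: `urllib.request.Request` has accepted an explicit `method` since Python 3.3 and, when `method` is None, falls back to GET (or POST once data is attached), so existing callers keep their old behavior. A small sketch (the URLs are made up):

```python
import urllib.request

headers = {"Content-Type": "application/json; charset=utf-8"}

# method=None preserves urllib's default verb selection...
get_req = urllib.request.Request(
    "http://localhost:30000/get_model_info", headers=headers, method=None
)
print(get_req.get_method())  # GET

# ...while method="POST" forces the verb, as flush_cache now does.
post_req = urllib.request.Request(
    "http://localhost:30000/flush_cache", headers=headers, method="POST"
)
print(post_req.get_method())  # POST
```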