Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
9208618b
Unverified
Commit
9208618b
authored
Dec 11, 2024
by
SangBin Cho
Committed by
GitHub
Dec 11, 2024
Browse files
[Core] in batch prefix caching by delay scheduling (#2442)
parent
864bf2ba
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
87 additions
and
16 deletions
+87
-16
python/sglang/lang/backend/runtime_endpoint.py
python/sglang/lang/backend/runtime_endpoint.py
+1
-0
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+2
-0
python/sglang/srt/managers/schedule_policy.py
python/sglang/srt/managers/schedule_policy.py
+58
-7
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+1
-1
python/sglang/srt/mem_cache/base_prefix_cache.py
python/sglang/srt/mem_cache/base_prefix_cache.py
+2
-2
python/sglang/srt/mem_cache/chunk_cache.py
python/sglang/srt/mem_cache/chunk_cache.py
+2
-2
python/sglang/srt/mem_cache/radix_cache.py
python/sglang/srt/mem_cache/radix_cache.py
+12
-2
python/sglang/utils.py
python/sglang/utils.py
+9
-2
No files found.
python/sglang/lang/backend/runtime_endpoint.py
View file @
9208618b
...
@@ -55,6 +55,7 @@ class RuntimeEndpoint(BaseBackend):
...
@@ -55,6 +55,7 @@ class RuntimeEndpoint(BaseBackend):
self
.
base_url
+
"/flush_cache"
,
self
.
base_url
+
"/flush_cache"
,
api_key
=
self
.
api_key
,
api_key
=
self
.
api_key
,
verify
=
self
.
verify
,
verify
=
self
.
verify
,
method
=
"POST"
,
)
)
self
.
_assert_success
(
res
)
self
.
_assert_success
(
res
)
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
9208618b
...
@@ -256,6 +256,7 @@ class Req:
...
@@ -256,6 +256,7 @@ class Req:
# Prefix info
# Prefix info
self
.
prefix_indices
=
[]
self
.
prefix_indices
=
[]
# Tokens to run prefill. input_tokens - shared_prefix_tokens.
self
.
extend_input_len
=
0
self
.
extend_input_len
=
0
self
.
last_node
=
None
self
.
last_node
=
None
...
@@ -316,6 +317,7 @@ class Req:
...
@@ -316,6 +317,7 @@ class Req:
def
init_next_round_input
(
self
,
tree_cache
:
Optional
[
BasePrefixCache
]
=
None
):
def
init_next_round_input
(
self
,
tree_cache
:
Optional
[
BasePrefixCache
]
=
None
):
self
.
fill_ids
=
self
.
origin_input_ids
+
self
.
output_ids
self
.
fill_ids
=
self
.
origin_input_ids
+
self
.
output_ids
if
tree_cache
is
not
None
:
if
tree_cache
is
not
None
:
# tree cache is None if the prefix is not computed with tree cache.
self
.
prefix_indices
,
self
.
last_node
=
tree_cache
.
match_prefix
(
self
.
prefix_indices
,
self
.
last_node
=
tree_cache
.
match_prefix
(
rid
=
self
.
rid
,
key
=
self
.
adjust_max_prefix_ids
()
rid
=
self
.
rid
,
key
=
self
.
adjust_max_prefix_ids
()
)
)
...
...
python/sglang/srt/managers/schedule_policy.py
View file @
9208618b
...
@@ -20,9 +20,11 @@ from contextlib import contextmanager
...
@@ -20,9 +20,11 @@ from contextlib import contextmanager
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
import
torch
from
sglang.srt.managers.schedule_batch
import
Req
,
ScheduleBatch
from
sglang.srt.managers.schedule_batch
import
Req
,
ScheduleBatch
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.radix_cache
import
TreeNode
from
sglang.srt.mem_cache.radix_cache
import
RadixCache
,
TreeNode
# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
# This can prevent the server from being too conservative.
# This can prevent the server from being too conservative.
...
@@ -32,6 +34,13 @@ CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
...
@@ -32,6 +34,13 @@ CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
os
.
environ
.
get
(
"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION"
,
"4096"
)
os
.
environ
.
get
(
"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION"
,
"4096"
)
)
)
# The threshold to apply in-batch prefix caching.
# If we use too small value, in-batch prefix caching cannot be used. E.g.,
# imagine "the" prefix.
IN_BATCH_PREFIX_CACHING_THRESHOLD
=
int
(
os
.
environ
.
get
(
"SGLANG_IN_BATCH_PREFIX_CACHING_THRESHOLD"
,
"32"
)
)
class
SchedulePolicy
:
class
SchedulePolicy
:
def
__init__
(
self
,
policy
:
str
,
tree_cache
:
BasePrefixCache
):
def
__init__
(
self
,
policy
:
str
,
tree_cache
:
BasePrefixCache
):
...
@@ -51,18 +60,50 @@ class SchedulePolicy:
...
@@ -51,18 +60,50 @@ class SchedulePolicy:
# Compute matched prefix length
# Compute matched prefix length
prefix_computed
=
False
prefix_computed
=
False
# rid to deprioritize in the current run.
temporary_deprioritized
=
{}
if
policy
==
"lpm"
or
policy
==
"dfs-weight"
:
if
policy
==
"lpm"
or
policy
==
"dfs-weight"
:
# It is used to find the matching prefix for in-batch prefix caching.
temp_radix
=
RadixCache
(
None
,
None
,
False
)
for
r
in
waiting_queue
:
for
r
in
waiting_queue
:
prefix_ids
=
r
.
adjust_max_prefix_ids
()
# NOTE: the prefix_indices must always be aligned with last_node
# NOTE: the prefix_indices must always be aligned with last_node
r
.
prefix_indices
,
r
.
last_node
=
self
.
tree_cache
.
match_prefix
(
r
.
prefix_indices
,
r
.
last_node
=
self
.
tree_cache
.
match_prefix
(
rid
=
r
.
rid
,
key
=
r
.
adjust_max_
prefix_ids
()
rid
=
r
.
rid
,
key
=
prefix_ids
)
)
# NOTE(sang): This logic is for In-batch prefix caching;
# If there are more than 1 request that have small matching prefix from
# existing cache, but all those requests share the same prefix, we prefer
# to schedule only one of them so that we can increase the cache hit rate.
# We prefer to set IN_BATCH_PREFIX_CACHING_THRESHOLD > 0 because too small
# threshold means we cannot use in-batch prefix caching for short prefixes.
# It is kind of common when the engine is long running (e.g., imagine "the").
if
len
(
r
.
prefix_indices
)
<=
IN_BATCH_PREFIX_CACHING_THRESHOLD
:
in_batch_matching_prefixes
,
_
=
temp_radix
.
match_prefix
(
rid
=
r
.
rid
,
key
=
prefix_ids
)
if
(
len
(
in_batch_matching_prefixes
)
>=
IN_BATCH_PREFIX_CACHING_THRESHOLD
):
temporary_deprioritized
[
r
.
rid
]
=
r
else
:
temp_radix
.
insert
(
prefix_ids
,
torch
.
tensor
(
prefix_ids
))
prefix_computed
=
True
prefix_computed
=
True
if
policy
==
"lpm"
:
if
policy
==
"lpm"
:
# Longest Prefix Match
# Longest Prefix Match
waiting_queue
.
sort
(
key
=
lambda
x
:
-
len
(
x
.
prefix_indices
))
def
get_priority
(
r
:
Req
):
score
=
0
if
r
.
rid
in
temporary_deprioritized
:
score
=
float
(
"inf"
)
else
:
score
=
-
len
(
r
.
prefix_indices
)
return
score
waiting_queue
.
sort
(
key
=
get_priority
)
elif
policy
==
"fcfs"
:
elif
policy
==
"fcfs"
:
# first come first serve
# first come first serve
pass
pass
...
@@ -76,6 +117,7 @@ class SchedulePolicy:
...
@@ -76,6 +117,7 @@ class SchedulePolicy:
for
req
in
waiting_queue
:
for
req
in
waiting_queue
:
last_node_to_reqs
[
req
.
last_node
].
append
(
req
)
last_node_to_reqs
[
req
.
last_node
].
append
(
req
)
# node -> # of requests for that node.
node_to_weight
=
defaultdict
(
int
)
node_to_weight
=
defaultdict
(
int
)
for
node
in
last_node_to_reqs
:
for
node
in
last_node_to_reqs
:
node_to_weight
[
node
]
=
len
(
last_node_to_reqs
[
node
])
node_to_weight
[
node
]
=
len
(
last_node_to_reqs
[
node
])
...
@@ -87,7 +129,9 @@ class SchedulePolicy:
...
@@ -87,7 +129,9 @@ class SchedulePolicy:
node_to_weight
,
node_to_weight
,
last_node_to_reqs
,
last_node_to_reqs
,
waiting_queue
,
waiting_queue
,
temporary_deprioritized
,
)
)
waiting_queue
.
extend
(
temporary_deprioritized
.
values
())
else
:
else
:
raise
ValueError
(
f
"Unknown schedule_policy:
{
policy
=
}
"
)
raise
ValueError
(
f
"Unknown schedule_policy:
{
policy
=
}
"
)
...
@@ -101,15 +145,22 @@ class SchedulePolicy:
...
@@ -101,15 +145,22 @@ class SchedulePolicy:
def
get_dfs_priority
(
def
get_dfs_priority
(
self
,
self
,
cur_node
:
TreeNode
,
cur_node
:
TreeNode
,
node_to_priority
:
Dict
,
node_to_priority
:
Dict
[
TreeNode
,
int
]
,
last_node_to_reqs
:
Dict
,
last_node_to_reqs
:
Dict
[
TreeNode
,
List
[
Req
]]
,
q
:
List
,
q
:
List
,
temporary_deprioritized
:
Dict
[
str
,
Req
],
):
):
childs
=
[
child
for
child
in
cur_node
.
children
.
values
()]
childs
=
[
child
for
child
in
cur_node
.
children
.
values
()]
childs
.
sort
(
key
=
lambda
x
:
-
node_to_priority
[
x
])
childs
.
sort
(
key
=
lambda
x
:
-
node_to_priority
[
x
])
for
child
in
childs
:
for
child
in
childs
:
self
.
get_dfs_priority
(
child
,
node_to_priority
,
last_node_to_reqs
,
q
)
self
.
get_dfs_priority
(
q
.
extend
(
last_node_to_reqs
[
cur_node
])
child
,
node_to_priority
,
last_node_to_reqs
,
q
,
temporary_deprioritized
)
for
req
in
last_node_to_reqs
[
cur_node
]:
if
req
.
rid
in
temporary_deprioritized
:
continue
q
.
append
(
req
)
class
AddReqResult
(
Enum
):
class
AddReqResult
(
Enum
):
...
...
python/sglang/srt/managers/scheduler.py
View file @
9208618b
...
@@ -713,7 +713,7 @@ class Scheduler:
...
@@ -713,7 +713,7 @@ class Scheduler:
if
crash_on_warnings
():
if
crash_on_warnings
():
raise
ValueError
(
msg
)
raise
ValueError
(
msg
)
def
get_next_batch_to_run
(
self
):
def
get_next_batch_to_run
(
self
)
->
Optional
[
ScheduleBatch
]
:
# Merge the prefill batch into the running batch
# Merge the prefill batch into the running batch
if
self
.
last_batch
and
self
.
last_batch
.
forward_mode
.
is_extend
():
if
self
.
last_batch
and
self
.
last_batch
.
forward_mode
.
is_extend
():
if
self
.
being_chunked_req
:
if
self
.
being_chunked_req
:
...
...
python/sglang/srt/mem_cache/base_prefix_cache.py
View file @
9208618b
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Callable
from
typing
import
Callable
,
List
,
Tuple
class
BasePrefixCache
(
ABC
):
class
BasePrefixCache
(
ABC
):
...
@@ -10,7 +10,7 @@ class BasePrefixCache(ABC):
...
@@ -10,7 +10,7 @@ class BasePrefixCache(ABC):
pass
pass
@
abstractmethod
@
abstractmethod
def
match_prefix
(
self
,
**
kwargs
):
def
match_prefix
(
self
,
**
kwargs
)
->
Tuple
[
List
[
int
],
int
]
:
pass
pass
@
abstractmethod
@
abstractmethod
...
...
python/sglang/srt/mem_cache/chunk_cache.py
View file @
9208618b
...
@@ -2,7 +2,7 @@ from __future__ import annotations
...
@@ -2,7 +2,7 @@ from __future__ import annotations
"""Cache for chunked prefill, used when RadixCache is disabled."""
"""Cache for chunked prefill, used when RadixCache is disabled."""
from
typing
import
TYPE_CHECKING
,
Callable
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Callable
,
List
,
Optional
,
Tuple
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.memory_pool
import
BaseTokenToKVPool
,
ReqToTokenPool
from
sglang.srt.mem_cache.memory_pool
import
BaseTokenToKVPool
,
ReqToTokenPool
...
@@ -30,7 +30,7 @@ class ChunkCache(BasePrefixCache):
...
@@ -30,7 +30,7 @@ class ChunkCache(BasePrefixCache):
def
reset
(
self
):
def
reset
(
self
):
self
.
entries
=
{}
self
.
entries
=
{}
def
match_prefix
(
self
,
rid
:
int
,
key
:
List
[
int
]):
def
match_prefix
(
self
,
rid
:
int
,
key
:
List
[
int
])
->
Tuple
[
List
[
int
],
int
]
:
if
rid
not
in
self
.
entries
:
if
rid
not
in
self
.
entries
:
return
[],
None
return
[],
None
...
...
python/sglang/srt/mem_cache/radix_cache.py
View file @
9208618b
...
@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
...
@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
import
heapq
import
heapq
import
time
import
time
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
TYPE_CHECKING
,
Callable
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Callable
,
List
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -76,7 +76,17 @@ class RadixCache(BasePrefixCache):
...
@@ -76,7 +76,17 @@ class RadixCache(BasePrefixCache):
self
.
root_node
.
lock_ref
=
1
self
.
root_node
.
lock_ref
=
1
self
.
evictable_size_
=
0
self
.
evictable_size_
=
0
def
match_prefix
(
self
,
key
:
List
,
**
kwargs
):
def
match_prefix
(
self
,
key
:
List
[
int
],
**
kwargs
)
->
Tuple
[
torch
.
Tensor
,
int
]:
"""Find the matching prefix from the radix tree.
Args:
key: A list of token IDs to find a matching prefix.
Returns:
A tuple of a tensor of matching prefix token IDs and
the last node that contains the prefix values. Note that
this API can modify the internal state of the Radix tree.
The last node create a new child if the prefix is shorter
than the last node's value.
"""
if
self
.
disable
:
if
self
.
disable
:
return
[],
self
.
root_node
return
[],
self
.
root_node
...
...
python/sglang/utils.py
View file @
9208618b
...
@@ -79,7 +79,14 @@ class HttpResponse:
...
@@ -79,7 +79,14 @@ class HttpResponse:
return
self
.
resp
.
status
return
self
.
resp
.
status
def
http_request
(
url
,
json
=
None
,
stream
=
False
,
api_key
=
None
,
verify
=
None
):
def
http_request
(
url
,
json
=
None
,
stream
=
False
,
api_key
=
None
,
verify
=
None
,
method
:
Optional
[
str
]
=
None
,
):
"""A faster version of requests.post with low-level urllib API."""
"""A faster version of requests.post with low-level urllib API."""
headers
=
{
"Content-Type"
:
"application/json; charset=utf-8"
}
headers
=
{
"Content-Type"
:
"application/json; charset=utf-8"
}
...
@@ -90,7 +97,7 @@ def http_request(url, json=None, stream=False, api_key=None, verify=None):
...
@@ -90,7 +97,7 @@ def http_request(url, json=None, stream=False, api_key=None, verify=None):
if
stream
:
if
stream
:
return
requests
.
post
(
url
,
json
=
json
,
stream
=
True
,
headers
=
headers
)
return
requests
.
post
(
url
,
json
=
json
,
stream
=
True
,
headers
=
headers
)
else
:
else
:
req
=
urllib
.
request
.
Request
(
url
,
headers
=
headers
)
req
=
urllib
.
request
.
Request
(
url
,
headers
=
headers
,
method
=
method
)
if
json
is
None
:
if
json
is
None
:
data
=
None
data
=
None
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment