Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
39191c85
You need to sign in or sign up before continuing.
Unverified
Commit
39191c85
authored
May 13, 2024
by
Liangsheng Yin
Committed by
GitHub
May 13, 2024
Browse files
Cache optimizations (#418)
parent
562b8857
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
117 additions
and
96 deletions
+117
-96
python/sglang/global_config.py
python/sglang/global_config.py
+3
-0
python/sglang/srt/backend_config.py
python/sglang/srt/backend_config.py
+0
-13
python/sglang/srt/managers/router/infer_batch.py
python/sglang/srt/managers/router/infer_batch.py
+17
-19
python/sglang/srt/managers/router/manager.py
python/sglang/srt/managers/router/manager.py
+4
-4
python/sglang/srt/managers/router/model_rpc.py
python/sglang/srt/managers/router/model_rpc.py
+27
-14
python/sglang/srt/managers/router/radix_cache.py
python/sglang/srt/managers/router/radix_cache.py
+47
-17
python/sglang/srt/managers/router/scheduler.py
python/sglang/srt/managers/router/scheduler.py
+17
-28
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+2
-1
No files found.
python/sglang/global_config.py
View file @
39191c85
...
...
@@ -25,5 +25,8 @@ class GlobalConfig:
# adjust_cache: Adjust the position embedding of KV cache.
self
.
concate_and_append_mode
=
"no_adjust"
# Request dependency time due to network delay
self
.
request_dependency_time
=
0.03
global_config
=
GlobalConfig
()
python/sglang/srt/backend_config.py
deleted
100644 → 0
View file @
562b8857
"""
Backend configurations, may vary with different serving platforms.
"""
from
dataclasses
import
dataclass
@
dataclass
class
BackendConfig
:
extend_dependency_time
:
float
=
0.03
GLOBAL_BACKEND_CONFIG
=
BackendConfig
()
python/sglang/srt/managers/router/infer_batch.py
View file @
39191c85
...
...
@@ -335,20 +335,20 @@ class Batch:
req
=
self
.
reqs
[
idx
]
retracted_reqs
.
append
(
req
)
self
.
tree_cache
.
dec_ref_counter
(
req
.
last_node
)
# TODO: apply more fine-grained retraction
last_uncached_pos
=
len
(
req
.
prefix_indices
)
token_indices
=
self
.
req_to_token_pool
.
req_to_token
[
req_pool_indices_cpu
[
idx
]
][
last_uncached_pos
:
seq_lens_cpu
[
idx
]]
self
.
token_to_kv_pool
.
dec_refs
(
token_indices
)
self
.
tree_cache
.
dec_lock_ref
(
req
.
last_node
)
req
.
prefix_indices
=
None
req
.
last_node
=
None
req
.
extend_input_len
=
0
req
.
output_ids
=
[]
req
.
regex_fsm_state
=
0
# TODO: apply more fine-grained retraction
token_indices
=
self
.
req_to_token_pool
.
req_to_token
[
req_pool_indices_cpu
[
idx
]
][:
seq_lens_cpu
[
idx
]]
self
.
token_to_kv_pool
.
dec_refs
(
token_indices
)
self
.
filter_batch
(
sorted_indices
)
return
retracted_reqs
...
...
@@ -367,20 +367,18 @@ class Batch:
if
len
(
jump_forward_str
)
<=
1
:
continue
# insert the old request into tree_cache
token_ids_in_memory
=
tuple
(
req
.
input_ids
+
req
.
output_ids
)[:
-
1
]
if
req_pool_indices_cpu
is
None
:
req_pool_indices_cpu
=
self
.
req_pool_indices
.
tolist
()
req_pool_idx
=
req_pool_indices_cpu
[
i
]
indices
=
self
.
req_to_token_pool
.
req_to_token
[
req_pool_idx
,
:
len
(
token_ids_in_memory
)
]
prefix_len
=
self
.
tree_cache
.
insert
(
token_ids_in_memory
,
indices
.
clone
()
# insert the old request into tree_cache
self
.
tree_cache
.
cache_req
(
token_ids
=
tuple
(
req
.
input_ids
+
req
.
output_ids
)[:
-
1
],
last_uncached_pos
=
len
(
req
.
prefix_indices
),
req_pool_idx
=
req_pool_indices_cpu
[
i
],
)
self
.
token_to_kv_pool
.
dec_refs
(
indices
[:
prefix_len
])
self
.
req_to_token_pool
.
free
(
req_pool_idx
)
self
.
tree_cache
.
dec_
ref_counter
(
req
.
last_node
)
# unlock the last node
self
.
tree_cache
.
dec_
lock_ref
(
req
.
last_node
)
# jump-forward
req
.
jump_forward_and_retokenize
(
jump_forward_str
,
next_state
)
...
...
python/sglang/srt/managers/router/manager.py
View file @
39191c85
...
...
@@ -5,7 +5,7 @@ import uvloop
import
zmq
import
zmq.asyncio
from
sglang
.srt.backend_config
import
GLOBAL_BACKEND_CONFIG
from
sglang
import
global_config
from
sglang.srt.managers.router.model_rpc
import
ModelRpcClient
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.utils
import
get_exception_traceback
...
...
@@ -30,7 +30,7 @@ class RouterManager:
self
.
recv_reqs
=
[]
# Init some configs
self
.
extend
_dependency_time
=
GLOBAL_BACKEND_CONFIG
.
extend
_dependency_time
self
.
request
_dependency_time
=
global_config
.
request
_dependency_time
async
def
loop_for_forward
(
self
):
while
True
:
...
...
@@ -46,9 +46,9 @@ class RouterManager:
if
len
(
out_pyobjs
)
!=
0
:
has_finished
=
any
([
obj
.
finished
for
obj
in
out_pyobjs
])
if
has_finished
:
if
self
.
extend
_dependency_time
>
0
:
if
self
.
request
_dependency_time
>
0
:
slept
=
True
await
asyncio
.
sleep
(
self
.
extend
_dependency_time
)
await
asyncio
.
sleep
(
self
.
request
_dependency_time
)
if
not
slept
:
await
asyncio
.
sleep
(
0.0006
)
...
...
python/sglang/srt/managers/router/model_rpc.py
View file @
39191c85
...
...
@@ -117,7 +117,11 @@ class ModelRpcServer:
logger
.
info
(
f
"server_args:
{
server_args
.
print_mode_args
()
}
"
)
# Init cache
self
.
tree_cache
=
RadixCache
(
disable
=
server_args
.
disable_radix_cache
)
self
.
tree_cache
=
RadixCache
(
req_to_token_pool
=
self
.
model_runner
.
req_to_token_pool
,
token_to_kv_pool
=
self
.
model_runner
.
token_to_kv_pool
,
disable
=
server_args
.
disable_radix_cache
,
)
self
.
tree_cache_metrics
=
{
"total"
:
0
,
"hit"
:
0
}
self
.
scheduler
=
Scheduler
(
self
.
schedule_heuristic
,
...
...
@@ -203,6 +207,8 @@ class ModelRpcServer:
# Run new fill batch
self
.
forward_fill_batch
(
new_batch
)
self
.
cache_filled_batch
(
new_batch
)
if
not
new_batch
.
is_empty
():
if
self
.
running_batch
is
None
:
self
.
running_batch
=
new_batch
...
...
@@ -349,20 +355,19 @@ class ModelRpcServer:
and
req
.
extend_input_len
+
new_batch_input_tokens
<
self
.
max_prefill_num_token
):
delta
=
self
.
tree_cache
.
inc_
ref_counter
(
req
.
last_node
)
delta
=
self
.
tree_cache
.
inc_
lock_ref
(
req
.
last_node
)
available_size
+=
delta
if
not
(
req
.
extend_input_len
+
req
.
max_new_tokens
()
+
new_batch_total_tokens
<
available_size
):
# Undo
the insertion
delta
=
self
.
tree_cache
.
dec_
ref_counter
(
req
.
last_node
)
# Undo
locking
delta
=
self
.
tree_cache
.
dec_
lock_ref
(
req
.
last_node
)
available_size
+=
delta
break
else
:
# Add this request to the running batch
self
.
token_to_kv_pool
.
add_refs
(
req
.
prefix_indices
)
can_run_list
.
append
(
req
)
new_batch_total_tokens
+=
(
req
.
extend_input_len
+
req
.
max_new_tokens
()
...
...
@@ -477,6 +482,18 @@ class ModelRpcServer:
self
.
handle_finished_requests
(
batch
)
def
cache_filled_batch
(
self
,
batch
:
Batch
):
req_pool_indices_cpu
=
batch
.
req_pool_indices
.
cpu
().
tolist
()
for
i
,
req
in
enumerate
(
batch
.
reqs
):
new_prefix_indices
,
new_last_node
=
self
.
tree_cache
.
cache_req
(
token_ids
=
tuple
(
req
.
input_ids
+
req
.
output_ids
)[:
-
1
],
last_uncached_pos
=
len
(
req
.
prefix_indices
),
req_pool_idx
=
req_pool_indices_cpu
[
i
],
del_in_memory_pool
=
False
,
old_last_node
=
req
.
last_node
,
)
req
.
prefix_indices
,
req
.
last_node
=
new_prefix_indices
,
new_last_node
def
forward_decode_batch
(
self
,
batch
:
Batch
):
# check if decode out of memory
if
not
batch
.
check_decode_mem
():
...
...
@@ -636,17 +653,13 @@ class ModelRpcServer:
req_pool_indices_cpu
=
batch
.
req_pool_indices
.
tolist
()
for
i
in
finished_indices
:
req
=
batch
.
reqs
[
i
]
req_pool_idx
=
req_pool_indices_cpu
[
i
]
token_ids
=
tuple
(
req
.
input_ids
+
req
.
output_ids
)
seq_len
=
len
(
token_ids
)
-
1
indices
=
self
.
req_to_token_pool
.
req_to_token
[
req_pool_idx
,
:
seq_len
]
prefix_len
=
self
.
tree_cache
.
insert
(
token_ids
[:
seq_len
],
indices
.
clone
()
self
.
tree_cache
.
cache_req
(
token_ids
=
tuple
(
req
.
input_ids
+
req
.
output_ids
)[:
-
1
],
last_uncached_pos
=
len
(
req
.
prefix_indices
),
req_pool_idx
=
req_pool_indices_cpu
[
i
],
)
self
.
token_to_kv_pool
.
dec_refs
(
indices
[:
prefix_len
])
self
.
req_to_token_pool
.
free
(
req_pool_idx
)
self
.
tree_cache
.
dec_ref_counter
(
req
.
last_node
)
self
.
tree_cache
.
dec_lock_ref
(
req
.
last_node
)
# Update batch tensors
if
unfinished_indices
:
...
...
python/sglang/srt/managers/router/radix_cache.py
View file @
39191c85
...
...
@@ -11,7 +11,7 @@ class TreeNode:
self
.
parent
=
None
self
.
key
=
None
self
.
value
=
None
self
.
ref_counter
=
0
self
.
lock_ref
=
0
self
.
last_access_time
=
time
.
time
()
def
__lt__
(
self
,
other
:
"TreeNode"
):
...
...
@@ -28,7 +28,9 @@ def _key_match(key0, key1):
class
RadixCache
:
def
__init__
(
self
,
disable
:
bool
=
False
):
def
__init__
(
self
,
req_to_token_pool
,
token_to_kv_pool
,
disable
:
bool
=
False
):
self
.
req_to_token_pool
=
req_to_token_pool
self
.
token_to_kv_pool
=
token_to_kv_pool
self
.
disable
=
disable
self
.
reset
()
...
...
@@ -38,7 +40,7 @@ class RadixCache:
self
.
root_node
=
TreeNode
()
self
.
root_node
.
key
=
[]
self
.
root_node
.
value
=
[]
self
.
root_node
.
ref_counter
=
1
self
.
root_node
.
lock_ref
=
1
self
.
evictable_size_
=
0
def
match_prefix
(
self
,
key
):
...
...
@@ -50,6 +52,8 @@ class RadixCache:
self
.
_match_prefix_helper
(
self
.
root_node
,
key
,
value
,
last_node
)
if
value
:
value
=
torch
.
concat
(
value
)
else
:
value
=
torch
.
tensor
([],
dtype
=
torch
.
int64
)
return
value
,
last_node
[
0
]
def
insert
(
self
,
key
,
value
=
None
):
...
...
@@ -60,6 +64,34 @@ class RadixCache:
value
=
[
x
for
x
in
key
]
return
self
.
_insert_helper
(
self
.
root_node
,
key
,
value
)
def
cache_req
(
self
,
token_ids
,
last_uncached_pos
,
req_pool_idx
,
del_in_memory_pool
=
True
,
old_last_node
=
None
,
):
# Insert the request into radix cache
indices
=
self
.
req_to_token_pool
.
req_to_token
[
req_pool_idx
,
:
len
(
token_ids
)]
new_prefix_len
=
self
.
insert
(
token_ids
,
indices
.
clone
())
# Radix Cache takes one ref in memory pool
self
.
token_to_kv_pool
.
dec_refs
(
indices
[
last_uncached_pos
:
new_prefix_len
])
if
del_in_memory_pool
:
self
.
req_to_token_pool
.
free
(
req_pool_idx
)
else
:
cached_indices
,
new_last_node
=
self
.
match_prefix
(
token_ids
)
assert
len
(
cached_indices
)
==
len
(
token_ids
)
self
.
req_to_token_pool
.
req_to_token
[
req_pool_idx
,
last_uncached_pos
:
len
(
cached_indices
)
]
=
cached_indices
[
last_uncached_pos
:]
self
.
dec_lock_ref
(
old_last_node
)
self
.
inc_lock_ref
(
new_last_node
)
return
cached_indices
,
new_last_node
def
pretty_print
(
self
):
self
.
_print_helper
(
self
.
root_node
,
0
)
print
(
f
"#tokens:
{
self
.
total_size
()
}
"
)
...
...
@@ -80,7 +112,7 @@ class RadixCache:
if
x
==
self
.
root_node
:
break
if
x
.
ref_counter
>
0
:
if
x
.
lock_ref
>
0
:
continue
num_evicted
+=
evict_callback
(
x
.
value
)
...
...
@@ -89,23 +121,23 @@ class RadixCache:
if
len
(
x
.
parent
.
children
)
==
0
:
heapq
.
heappush
(
leaves
,
x
.
parent
)
def
inc_
ref_counter
(
self
,
n
ode
):
def
inc_
lock_ref
(
self
,
node
:
TreeN
ode
):
delta
=
0
while
node
!=
self
.
root_node
:
if
node
.
ref_counter
==
0
:
if
node
.
lock_ref
==
0
:
self
.
evictable_size_
-=
len
(
node
.
value
)
delta
-=
len
(
node
.
value
)
node
.
ref_counter
+=
1
node
.
lock_ref
+=
1
node
=
node
.
parent
return
delta
def
dec_
ref_counter
(
self
,
n
ode
):
def
dec_
lock_ref
(
self
,
node
:
TreeN
ode
):
delta
=
0
while
node
!=
self
.
root_node
:
if
node
.
ref_counter
==
1
:
if
node
.
lock_ref
==
1
:
self
.
evictable_size_
+=
len
(
node
.
value
)
delta
+=
len
(
node
.
value
)
node
.
ref_counter
-=
1
node
.
lock_ref
-=
1
node
=
node
.
parent
return
delta
...
...
@@ -131,12 +163,12 @@ class RadixCache:
last_node
[
0
]
=
child
self
.
_match_prefix_helper
(
child
,
key
[
prefix_len
:],
value
,
last_node
)
def
_split_node
(
self
,
key
,
child
,
split_len
):
def
_split_node
(
self
,
key
,
child
:
TreeNode
,
split_len
):
# new_node -> child
new_node
=
TreeNode
()
new_node
.
children
=
{
key
[
split_len
:][
0
]:
child
}
new_node
.
parent
=
child
.
parent
new_node
.
ref_counter
=
child
.
ref_counter
new_node
.
lock_ref
=
child
.
lock_ref
new_node
.
key
=
child
.
key
[:
split_len
]
new_node
.
value
=
child
.
value
[:
split_len
]
child
.
parent
=
new_node
...
...
@@ -176,11 +208,9 @@ class RadixCache:
self
.
evictable_size_
+=
len
(
value
)
return
0
def
_print_helper
(
self
,
node
,
indent
):
def
_print_helper
(
self
,
node
:
TreeNode
,
indent
):
for
_
,
child
in
node
.
children
.
items
():
print
(
" "
*
indent
,
len
(
child
.
key
),
child
.
key
[:
10
],
f
"r=
{
child
.
ref_counter
}
"
)
print
(
" "
*
indent
,
len
(
child
.
key
),
child
.
key
[:
10
],
f
"r=
{
child
.
lock_ref
}
"
)
self
.
_print_helper
(
child
,
indent
=
indent
+
2
)
def
_delete_leaf
(
self
,
node
):
...
...
@@ -211,7 +241,7 @@ class RadixCache:
if
__name__
==
"__main__"
:
tree
=
RadixCache
()
tree
=
RadixCache
(
None
,
None
,
False
)
tree
.
insert
(
"Hello"
)
tree
.
insert
(
"Hello"
)
...
...
python/sglang/srt/managers/router/scheduler.py
View file @
39191c85
...
...
@@ -27,44 +27,33 @@ class Scheduler:
return
forward_queue
elif
self
.
schedule_heuristic
==
"fcfs"
:
return
forward_queue
elif
self
.
schedule_heuristic
==
"weight"
:
elif
self
.
schedule_heuristic
==
"
dfs-
weight"
:
last_node_to_reqs
=
defaultdict
(
list
)
for
req
in
forward_queue
:
last_node_to_reqs
[
req
.
last_node
].
append
(
req
)
for
node
in
last_node_to_reqs
:
last_node_to_reqs
[
node
].
sort
(
key
=
lambda
x
:
-
len
(
x
.
prefix_indices
))
node_to_weight
=
defaultdict
(
int
)
self
.
_calc_weight_recursive
(
self
.
tree_cache
.
root_node
,
last_node_to_reqs
,
node
_to_weight
)
for
node
in
last_node_to_reqs
:
node_to_weight
[
node
]
=
len
(
last_node_to_reqs
[
node
])
self
.
calc_weight
(
self
.
tree_cache
.
root_node
,
node_to_weight
)
tmp_queue
=
[]
self
.
_
get_
weight
_priority
_recursive
(
self
.
tree_cache
.
root_node
,
node_to_weight
,
last_node_to_reqs
,
tmp_queue
q
=
[]
self
.
get_
dfs
_priority
(
self
.
tree_cache
.
root_node
,
node_to_weight
,
last_node_to_reqs
,
q
)
assert
len
(
tmp_queue
)
==
len
(
forward_queue
)
return
tmp_queue
assert
len
(
q
)
==
len
(
forward_queue
)
return
q
else
:
raise
ValueError
(
f
"Unknown schedule_heuristic:
{
self
.
schedule_heuristic
}
"
)
def
_calc_weight_recursive
(
self
,
cur_node
,
last_node_to_reqs
,
node_to_weight
):
node_to_weight
[
cur_node
]
=
1
if
cur_node
in
last_node_to_reqs
:
node_to_weight
[
cur_node
]
+=
len
(
last_node_to_reqs
[
cur_node
])
def
calc_weight
(
self
,
cur_node
,
node_to_weight
):
for
child
in
cur_node
.
children
.
values
():
self
.
_
calc_weight
_recursive
(
child
,
last_node_to_reqs
,
node_to_weight
)
self
.
calc_weight
(
child
,
node_to_weight
)
node_to_weight
[
cur_node
]
+=
node_to_weight
[
child
]
def
_get_weight_priority_recursive
(
self
,
cur_node
,
node_to_wight
,
last_node_to_reqs
,
tmp_queue
):
visit_list
=
[
child
for
child
in
cur_node
.
children
.
values
()]
visit_list
.
sort
(
key
=
lambda
x
:
-
node_to_wight
[
x
])
# for node in visit_list:
# print(f"{node_to_wight[node]} {len(node.value) if node.value is not None else 0}")
for
child
in
visit_list
:
self
.
_get_weight_priority_recursive
(
child
,
node_to_wight
,
last_node_to_reqs
,
tmp_queue
)
tmp_queue
.
extend
(
last_node_to_reqs
[
cur_node
])
def
get_dfs_priority
(
self
,
cur_node
,
node_to_priority
,
last_node_to_reqs
,
q
):
childs
=
[
child
for
child
in
cur_node
.
children
.
values
()]
childs
.
sort
(
key
=
lambda
x
:
-
node_to_priority
[
x
])
for
child
in
childs
:
self
.
get_dfs_priority
(
child
,
node_to_priority
,
last_node_to_reqs
,
q
)
q
.
extend
(
last_node_to_reqs
[
cur_node
])
python/sglang/srt/server_args.py
View file @
39191c85
...
...
@@ -149,7 +149,8 @@ class ServerArgs:
"--schedule-heuristic"
,
type
=
str
,
default
=
ServerArgs
.
schedule_heuristic
,
help
=
"Schudule mode: [lpm, weight, random, fcfs]"
,
choices
=
[
"lpm"
,
"random"
,
"fcfs"
,
"dfs-weight"
],
help
=
"Scheduling Heuristic."
,
)
parser
.
add_argument
(
"--schedule-conservativeness"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment