Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
39191c85
Unverified
Commit
39191c85
authored
May 13, 2024
by
Liangsheng Yin
Committed by
GitHub
May 13, 2024
Browse files
Cache optimizations (#418)
parent
562b8857
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
117 additions
and
96 deletions
+117
-96
python/sglang/global_config.py
python/sglang/global_config.py
+3
-0
python/sglang/srt/backend_config.py
python/sglang/srt/backend_config.py
+0
-13
python/sglang/srt/managers/router/infer_batch.py
python/sglang/srt/managers/router/infer_batch.py
+17
-19
python/sglang/srt/managers/router/manager.py
python/sglang/srt/managers/router/manager.py
+4
-4
python/sglang/srt/managers/router/model_rpc.py
python/sglang/srt/managers/router/model_rpc.py
+27
-14
python/sglang/srt/managers/router/radix_cache.py
python/sglang/srt/managers/router/radix_cache.py
+47
-17
python/sglang/srt/managers/router/scheduler.py
python/sglang/srt/managers/router/scheduler.py
+17
-28
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+2
-1
No files found.
python/sglang/global_config.py
View file @
39191c85
...
...
@@ -25,5 +25,8 @@ class GlobalConfig:
# adjust_cache: Adjust the position embedding of KV cache.
self
.
concate_and_append_mode
=
"no_adjust"
# Request dependency time due to network delay
self
.
request_dependency_time
=
0.03
global_config
=
GlobalConfig
()
python/sglang/srt/backend_config.py
deleted
100644 → 0
View file @
562b8857
"""
Backend configurations, may vary with different serving platforms.
"""
from
dataclasses
import
dataclass
@
dataclass
class
BackendConfig
:
extend_dependency_time
:
float
=
0.03
GLOBAL_BACKEND_CONFIG
=
BackendConfig
()
python/sglang/srt/managers/router/infer_batch.py
View file @
39191c85
...
...
@@ -335,20 +335,20 @@ class Batch:
req
=
self
.
reqs
[
idx
]
retracted_reqs
.
append
(
req
)
self
.
tree_cache
.
dec_ref_counter
(
req
.
last_node
)
# TODO: apply more fine-grained retraction
last_uncached_pos
=
len
(
req
.
prefix_indices
)
token_indices
=
self
.
req_to_token_pool
.
req_to_token
[
req_pool_indices_cpu
[
idx
]
][
last_uncached_pos
:
seq_lens_cpu
[
idx
]]
self
.
token_to_kv_pool
.
dec_refs
(
token_indices
)
self
.
tree_cache
.
dec_lock_ref
(
req
.
last_node
)
req
.
prefix_indices
=
None
req
.
last_node
=
None
req
.
extend_input_len
=
0
req
.
output_ids
=
[]
req
.
regex_fsm_state
=
0
# TODO: apply more fine-grained retraction
token_indices
=
self
.
req_to_token_pool
.
req_to_token
[
req_pool_indices_cpu
[
idx
]
][:
seq_lens_cpu
[
idx
]]
self
.
token_to_kv_pool
.
dec_refs
(
token_indices
)
self
.
filter_batch
(
sorted_indices
)
return
retracted_reqs
...
...
@@ -367,20 +367,18 @@ class Batch:
if
len
(
jump_forward_str
)
<=
1
:
continue
# insert the old request into tree_cache
token_ids_in_memory
=
tuple
(
req
.
input_ids
+
req
.
output_ids
)[:
-
1
]
if
req_pool_indices_cpu
is
None
:
req_pool_indices_cpu
=
self
.
req_pool_indices
.
tolist
()
req_pool_idx
=
req_pool_indices_cpu
[
i
]
indices
=
self
.
req_to_token_pool
.
req_to_token
[
req_pool_idx
,
:
len
(
token_ids_in_memory
)
]
prefix_len
=
self
.
tree_cache
.
insert
(
token_ids_in_memory
,
indices
.
clone
()
# insert the old request into tree_cache
self
.
tree_cache
.
cache_req
(
token_ids
=
tuple
(
req
.
input_ids
+
req
.
output_ids
)[:
-
1
],
last_uncached_pos
=
len
(
req
.
prefix_indices
),
req_pool_idx
=
req_pool_indices_cpu
[
i
],
)
self
.
token_to_kv_pool
.
dec_refs
(
indices
[:
prefix_len
])
self
.
req_to_token_pool
.
free
(
req_pool_idx
)
self
.
tree_cache
.
dec_
ref_counter
(
req
.
last_node
)
# unlock the last node
self
.
tree_cache
.
dec_
lock_ref
(
req
.
last_node
)
# jump-forward
req
.
jump_forward_and_retokenize
(
jump_forward_str
,
next_state
)
...
...
python/sglang/srt/managers/router/manager.py
View file @
39191c85
...
...
@@ -5,7 +5,7 @@ import uvloop
import
zmq
import
zmq.asyncio
from
sglang
.srt.backend_config
import
GLOBAL_BACKEND_CONFIG
from
sglang
import
global_config
from
sglang.srt.managers.router.model_rpc
import
ModelRpcClient
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.utils
import
get_exception_traceback
...
...
@@ -30,7 +30,7 @@ class RouterManager:
self
.
recv_reqs
=
[]
# Init some configs
self
.
extend
_dependency_time
=
GLOBAL_BACKEND_CONFIG
.
extend
_dependency_time
self
.
request
_dependency_time
=
global_config
.
request
_dependency_time
async
def
loop_for_forward
(
self
):
while
True
:
...
...
@@ -46,9 +46,9 @@ class RouterManager:
if
len
(
out_pyobjs
)
!=
0
:
has_finished
=
any
([
obj
.
finished
for
obj
in
out_pyobjs
])
if
has_finished
:
if
self
.
extend
_dependency_time
>
0
:
if
self
.
request
_dependency_time
>
0
:
slept
=
True
await
asyncio
.
sleep
(
self
.
extend
_dependency_time
)
await
asyncio
.
sleep
(
self
.
request
_dependency_time
)
if
not
slept
:
await
asyncio
.
sleep
(
0.0006
)
...
...
python/sglang/srt/managers/router/model_rpc.py
View file @
39191c85
...
...
@@ -117,7 +117,11 @@ class ModelRpcServer:
logger
.
info
(
f
"server_args:
{
server_args
.
print_mode_args
()
}
"
)
# Init cache
self
.
tree_cache
=
RadixCache
(
disable
=
server_args
.
disable_radix_cache
)
self
.
tree_cache
=
RadixCache
(
req_to_token_pool
=
self
.
model_runner
.
req_to_token_pool
,
token_to_kv_pool
=
self
.
model_runner
.
token_to_kv_pool
,
disable
=
server_args
.
disable_radix_cache
,
)
self
.
tree_cache_metrics
=
{
"total"
:
0
,
"hit"
:
0
}
self
.
scheduler
=
Scheduler
(
self
.
schedule_heuristic
,
...
...
@@ -203,6 +207,8 @@ class ModelRpcServer:
# Run new fill batch
self
.
forward_fill_batch
(
new_batch
)
self
.
cache_filled_batch
(
new_batch
)
if
not
new_batch
.
is_empty
():
if
self
.
running_batch
is
None
:
self
.
running_batch
=
new_batch
...
...
@@ -349,20 +355,19 @@ class ModelRpcServer:
and
req
.
extend_input_len
+
new_batch_input_tokens
<
self
.
max_prefill_num_token
):
delta
=
self
.
tree_cache
.
inc_
ref_counter
(
req
.
last_node
)
delta
=
self
.
tree_cache
.
inc_
lock_ref
(
req
.
last_node
)
available_size
+=
delta
if
not
(
req
.
extend_input_len
+
req
.
max_new_tokens
()
+
new_batch_total_tokens
<
available_size
):
# Undo
the insertion
delta
=
self
.
tree_cache
.
dec_
ref_counter
(
req
.
last_node
)
# Undo
locking
delta
=
self
.
tree_cache
.
dec_
lock_ref
(
req
.
last_node
)
available_size
+=
delta
break
else
:
# Add this request to the running batch
self
.
token_to_kv_pool
.
add_refs
(
req
.
prefix_indices
)
can_run_list
.
append
(
req
)
new_batch_total_tokens
+=
(
req
.
extend_input_len
+
req
.
max_new_tokens
()
...
...
@@ -477,6 +482,18 @@ class ModelRpcServer:
self
.
handle_finished_requests
(
batch
)
def
cache_filled_batch
(
self
,
batch
:
Batch
):
req_pool_indices_cpu
=
batch
.
req_pool_indices
.
cpu
().
tolist
()
for
i
,
req
in
enumerate
(
batch
.
reqs
):
new_prefix_indices
,
new_last_node
=
self
.
tree_cache
.
cache_req
(
token_ids
=
tuple
(
req
.
input_ids
+
req
.
output_ids
)[:
-
1
],
last_uncached_pos
=
len
(
req
.
prefix_indices
),
req_pool_idx
=
req_pool_indices_cpu
[
i
],
del_in_memory_pool
=
False
,
old_last_node
=
req
.
last_node
,
)
req
.
prefix_indices
,
req
.
last_node
=
new_prefix_indices
,
new_last_node
def
forward_decode_batch
(
self
,
batch
:
Batch
):
# check if decode out of memory
if
not
batch
.
check_decode_mem
():
...
...
@@ -636,17 +653,13 @@ class ModelRpcServer:
req_pool_indices_cpu
=
batch
.
req_pool_indices
.
tolist
()
for
i
in
finished_indices
:
req
=
batch
.
reqs
[
i
]
req_pool_idx
=
req_pool_indices_cpu
[
i
]
token_ids
=
tuple
(
req
.
input_ids
+
req
.
output_ids
)
seq_len
=
len
(
token_ids
)
-
1
indices
=
self
.
req_to_token_pool
.
req_to_token
[
req_pool_idx
,
:
seq_len
]
prefix_len
=
self
.
tree_cache
.
insert
(
token_ids
[:
seq_len
],
indices
.
clone
()
self
.
tree_cache
.
cache_req
(
token_ids
=
tuple
(
req
.
input_ids
+
req
.
output_ids
)[:
-
1
],
last_uncached_pos
=
len
(
req
.
prefix_indices
),
req_pool_idx
=
req_pool_indices_cpu
[
i
],
)
self
.
token_to_kv_pool
.
dec_refs
(
indices
[:
prefix_len
])
self
.
req_to_token_pool
.
free
(
req_pool_idx
)
self
.
tree_cache
.
dec_ref_counter
(
req
.
last_node
)
self
.
tree_cache
.
dec_lock_ref
(
req
.
last_node
)
# Update batch tensors
if
unfinished_indices
:
...
...
python/sglang/srt/managers/router/radix_cache.py
View file @
39191c85
...
...
@@ -11,7 +11,7 @@ class TreeNode:
self
.
parent
=
None
self
.
key
=
None
self
.
value
=
None
self
.
ref_counter
=
0
self
.
lock_ref
=
0
self
.
last_access_time
=
time
.
time
()
def
__lt__
(
self
,
other
:
"TreeNode"
):
...
...
@@ -28,7 +28,9 @@ def _key_match(key0, key1):
class
RadixCache
:
def
__init__
(
self
,
disable
:
bool
=
False
):
def
__init__
(
self
,
req_to_token_pool
,
token_to_kv_pool
,
disable
:
bool
=
False
):
self
.
req_to_token_pool
=
req_to_token_pool
self
.
token_to_kv_pool
=
token_to_kv_pool
self
.
disable
=
disable
self
.
reset
()
...
...
@@ -38,7 +40,7 @@ class RadixCache:
self
.
root_node
=
TreeNode
()
self
.
root_node
.
key
=
[]
self
.
root_node
.
value
=
[]
self
.
root_node
.
ref_counter
=
1
self
.
root_node
.
lock_ref
=
1
self
.
evictable_size_
=
0
def
match_prefix
(
self
,
key
):
...
...
@@ -50,6 +52,8 @@ class RadixCache:
self
.
_match_prefix_helper
(
self
.
root_node
,
key
,
value
,
last_node
)
if
value
:
value
=
torch
.
concat
(
value
)
else
:
value
=
torch
.
tensor
([],
dtype
=
torch
.
int64
)
return
value
,
last_node
[
0
]
def
insert
(
self
,
key
,
value
=
None
):
...
...
@@ -60,6 +64,34 @@ class RadixCache:
value
=
[
x
for
x
in
key
]
return
self
.
_insert_helper
(
self
.
root_node
,
key
,
value
)
def
cache_req
(
self
,
token_ids
,
last_uncached_pos
,
req_pool_idx
,
del_in_memory_pool
=
True
,
old_last_node
=
None
,
):
# Insert the request into radix cache
indices
=
self
.
req_to_token_pool
.
req_to_token
[
req_pool_idx
,
:
len
(
token_ids
)]
new_prefix_len
=
self
.
insert
(
token_ids
,
indices
.
clone
())
# Radix Cache takes one ref in memory pool
self
.
token_to_kv_pool
.
dec_refs
(
indices
[
last_uncached_pos
:
new_prefix_len
])
if
del_in_memory_pool
:
self
.
req_to_token_pool
.
free
(
req_pool_idx
)
else
:
cached_indices
,
new_last_node
=
self
.
match_prefix
(
token_ids
)
assert
len
(
cached_indices
)
==
len
(
token_ids
)
self
.
req_to_token_pool
.
req_to_token
[
req_pool_idx
,
last_uncached_pos
:
len
(
cached_indices
)
]
=
cached_indices
[
last_uncached_pos
:]
self
.
dec_lock_ref
(
old_last_node
)
self
.
inc_lock_ref
(
new_last_node
)
return
cached_indices
,
new_last_node
def
pretty_print
(
self
):
self
.
_print_helper
(
self
.
root_node
,
0
)
print
(
f
"#tokens:
{
self
.
total_size
()
}
"
)
...
...
@@ -80,7 +112,7 @@ class RadixCache:
if
x
==
self
.
root_node
:
break
if
x
.
ref_counter
>
0
:
if
x
.
lock_ref
>
0
:
continue
num_evicted
+=
evict_callback
(
x
.
value
)
...
...
@@ -89,23 +121,23 @@ class RadixCache:
if
len
(
x
.
parent
.
children
)
==
0
:
heapq
.
heappush
(
leaves
,
x
.
parent
)
def
inc_
ref_counter
(
self
,
n
ode
):
def
inc_
lock_ref
(
self
,
node
:
TreeN
ode
):
delta
=
0
while
node
!=
self
.
root_node
:
if
node
.
ref_counter
==
0
:
if
node
.
lock_ref
==
0
:
self
.
evictable_size_
-=
len
(
node
.
value
)
delta
-=
len
(
node
.
value
)
node
.
ref_counter
+=
1
node
.
lock_ref
+=
1
node
=
node
.
parent
return
delta
def
dec_
ref_counter
(
self
,
n
ode
):
def
dec_
lock_ref
(
self
,
node
:
TreeN
ode
):
delta
=
0
while
node
!=
self
.
root_node
:
if
node
.
ref_counter
==
1
:
if
node
.
lock_ref
==
1
:
self
.
evictable_size_
+=
len
(
node
.
value
)
delta
+=
len
(
node
.
value
)
node
.
ref_counter
-=
1
node
.
lock_ref
-=
1
node
=
node
.
parent
return
delta
...
...
@@ -131,12 +163,12 @@ class RadixCache:
last_node
[
0
]
=
child
self
.
_match_prefix_helper
(
child
,
key
[
prefix_len
:],
value
,
last_node
)
def
_split_node
(
self
,
key
,
child
,
split_len
):
def
_split_node
(
self
,
key
,
child
:
TreeNode
,
split_len
):
# new_node -> child
new_node
=
TreeNode
()
new_node
.
children
=
{
key
[
split_len
:][
0
]:
child
}
new_node
.
parent
=
child
.
parent
new_node
.
ref_counter
=
child
.
ref_counter
new_node
.
lock_ref
=
child
.
lock_ref
new_node
.
key
=
child
.
key
[:
split_len
]
new_node
.
value
=
child
.
value
[:
split_len
]
child
.
parent
=
new_node
...
...
@@ -176,11 +208,9 @@ class RadixCache:
self
.
evictable_size_
+=
len
(
value
)
return
0
def
_print_helper
(
self
,
node
,
indent
):
def
_print_helper
(
self
,
node
:
TreeNode
,
indent
):
for
_
,
child
in
node
.
children
.
items
():
print
(
" "
*
indent
,
len
(
child
.
key
),
child
.
key
[:
10
],
f
"r=
{
child
.
ref_counter
}
"
)
print
(
" "
*
indent
,
len
(
child
.
key
),
child
.
key
[:
10
],
f
"r=
{
child
.
lock_ref
}
"
)
self
.
_print_helper
(
child
,
indent
=
indent
+
2
)
def
_delete_leaf
(
self
,
node
):
...
...
@@ -211,7 +241,7 @@ class RadixCache:
if
__name__
==
"__main__"
:
tree
=
RadixCache
()
tree
=
RadixCache
(
None
,
None
,
False
)
tree
.
insert
(
"Hello"
)
tree
.
insert
(
"Hello"
)
...
...
python/sglang/srt/managers/router/scheduler.py
View file @
39191c85
...
...
@@ -27,44 +27,33 @@ class Scheduler:
return
forward_queue
elif
self
.
schedule_heuristic
==
"fcfs"
:
return
forward_queue
elif
self
.
schedule_heuristic
==
"weight"
:
elif
self
.
schedule_heuristic
==
"
dfs-
weight"
:
last_node_to_reqs
=
defaultdict
(
list
)
for
req
in
forward_queue
:
last_node_to_reqs
[
req
.
last_node
].
append
(
req
)
for
node
in
last_node_to_reqs
:
last_node_to_reqs
[
node
].
sort
(
key
=
lambda
x
:
-
len
(
x
.
prefix_indices
))
node_to_weight
=
defaultdict
(
int
)
self
.
_calc_weight_recursive
(
self
.
tree_cache
.
root_node
,
last_node_to_reqs
,
node
_to_weight
)
for
node
in
last_node_to_reqs
:
node_to_weight
[
node
]
=
len
(
last_node_to_reqs
[
node
])
self
.
calc_weight
(
self
.
tree_cache
.
root_node
,
node_to_weight
)
tmp_queue
=
[]
self
.
_
get_
weight
_priority
_recursive
(
self
.
tree_cache
.
root_node
,
node_to_weight
,
last_node_to_reqs
,
tmp_queue
q
=
[]
self
.
get_
dfs
_priority
(
self
.
tree_cache
.
root_node
,
node_to_weight
,
last_node_to_reqs
,
q
)
assert
len
(
tmp_queue
)
==
len
(
forward_queue
)
return
tmp_queue
assert
len
(
q
)
==
len
(
forward_queue
)
return
q
else
:
raise
ValueError
(
f
"Unknown schedule_heuristic:
{
self
.
schedule_heuristic
}
"
)
def
_calc_weight_recursive
(
self
,
cur_node
,
last_node_to_reqs
,
node_to_weight
):
node_to_weight
[
cur_node
]
=
1
if
cur_node
in
last_node_to_reqs
:
node_to_weight
[
cur_node
]
+=
len
(
last_node_to_reqs
[
cur_node
])
def
calc_weight
(
self
,
cur_node
,
node_to_weight
):
for
child
in
cur_node
.
children
.
values
():
self
.
_
calc_weight
_recursive
(
child
,
last_node_to_reqs
,
node_to_weight
)
self
.
calc_weight
(
child
,
node_to_weight
)
node_to_weight
[
cur_node
]
+=
node_to_weight
[
child
]
def
_get_weight_priority_recursive
(
self
,
cur_node
,
node_to_wight
,
last_node_to_reqs
,
tmp_queue
):
visit_list
=
[
child
for
child
in
cur_node
.
children
.
values
()]
visit_list
.
sort
(
key
=
lambda
x
:
-
node_to_wight
[
x
])
# for node in visit_list:
# print(f"{node_to_wight[node]} {len(node.value) if node.value is not None else 0}")
for
child
in
visit_list
:
self
.
_get_weight_priority_recursive
(
child
,
node_to_wight
,
last_node_to_reqs
,
tmp_queue
)
tmp_queue
.
extend
(
last_node_to_reqs
[
cur_node
])
def
get_dfs_priority
(
self
,
cur_node
,
node_to_priority
,
last_node_to_reqs
,
q
):
childs
=
[
child
for
child
in
cur_node
.
children
.
values
()]
childs
.
sort
(
key
=
lambda
x
:
-
node_to_priority
[
x
])
for
child
in
childs
:
self
.
get_dfs_priority
(
child
,
node_to_priority
,
last_node_to_reqs
,
q
)
q
.
extend
(
last_node_to_reqs
[
cur_node
])
python/sglang/srt/server_args.py
View file @
39191c85
...
...
@@ -149,7 +149,8 @@ class ServerArgs:
"--schedule-heuristic"
,
type
=
str
,
default
=
ServerArgs
.
schedule_heuristic
,
help
=
"Schudule mode: [lpm, weight, random, fcfs]"
,
choices
=
[
"lpm"
,
"random"
,
"fcfs"
,
"dfs-weight"
],
help
=
"Scheduling Heuristic."
,
)
parser
.
add_argument
(
"--schedule-conservativeness"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment