sglang / Commit 56198b45 (unverified)
Authored Dec 16, 2024 by Lianmin Zheng; committed by GitHub on Dec 16, 2024

Add a benchmark script for in-batch prefix caching (#2494)
parent ba36b552

Showing 2 changed files with 177 additions and 39 deletions:
  benchmark/bench_in_batch_prefix/bench_in_batch_prefix.py   +130  -0
  python/sglang/srt/managers/schedule_policy.py               +47  -39

benchmark/bench_in_batch_prefix/bench_in_batch_prefix.py (new file, mode 100644)
# Benchmark with lots of common prefixes. Used to benchmark prefix caching performance.
#
# Launch a server:
# python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --log-level-http warning

import random
import string
import time

from tqdm import tqdm
from transformers import AutoTokenizer

import sglang as sgl
from sglang import set_default_backend
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint


def generate_random_string(token_length: int) -> str:
    random_string = "".join(
        random.choices(string.ascii_letters + string.digits, k=token_length * 100)
    )
    tokenized_output = tokenizer.encode(random_string, add_special_tokens=False)[
        :token_length
    ]

    if len(tokenized_output) < token_length:
        tokenized_output = tokenized_output + [tokenizer.pad_token_id] * (
            token_length - len(tokenized_output)
        )

    decoded_string = tokenizer.decode(tokenized_output, skip_special_tokens=False)
    return decoded_string


def generate_unique_prefix(base_text, index):
    return str(index) + base_text[len(str(index)) :]


@sgl.function
def text_qa(s, question, gen_len):
    s += "Q: " + question + "\n"
    s += "A:" + sgl.gen("answer", stop="\n", temperature=0, max_tokens=gen_len)


def prepare_prompts(num_prefix, num_samples_per_prefix, prefix_length, suffix_length):
    base_prefix = generate_random_string(prefix_length)

    tot_input_len = 0
    all_prompts = []
    for i in tqdm(range(num_prefix), desc="prepare prompts"):
        unique_prefix = generate_unique_prefix(base_prefix, i)
        prompt_list = []
        for j in range(num_samples_per_prefix):
            suffix = generate_random_string(suffix_length)
            prompt = unique_prefix + suffix
            prompt_list.append(prompt)
            tot_input_len += len(tokenizer.encode(prompt))
        all_prompts.append(prompt_list)
    return all_prompts, tot_input_len


def test_batch_by_batch(all_prompts, gen_len):
    backend.flush_cache()

    tot_time = 0
    for i in range(len(all_prompts)):
        tic = time.time()
        text_qa.run_batch(
            list(zip(all_prompts[i], [gen_len] * len(all_prompts[i]))),
        )
        tot_time += time.time() - tic

    return tot_time


def test_batch_by_batch_with_hint(all_prompts, gen_len):
    backend.flush_cache()

    tot_time = 0
    for i in range(len(all_prompts)):
        tic = time.time()
        # Send a hint to cache the prefix
        text_qa.run_batch(list(zip(all_prompts[i][:1], [gen_len])))
        # Send the batch
        text_qa.run_batch(list(zip(all_prompts[i], [gen_len] * len(all_prompts[i]))))

        tot_time += time.time() - tic

    return tot_time


def test_send_all(all_prompts, gen_len):
    backend.flush_cache()

    all_prompts = [x for prompt_list in all_prompts for x in prompt_list]

    tic = time.time()
    text_qa.run_batch(
        list(zip(all_prompts, [gen_len] * len(all_prompts))),
    )
    tot_time = time.time() - tic

    return tot_time


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

    backend = RuntimeEndpoint("http://127.0.0.1:30000")
    set_default_backend(backend)

    random.seed(0)
    num_prefix = 10
    num_samples_per_prefix = 32
    prefix_length = 1024
    suffix_length = 128
    gen_len = 1

    all_prompts, tot_input_len = prepare_prompts(
        num_prefix, num_samples_per_prefix, prefix_length, suffix_length
    )
    print(f"Total input token length: {tot_input_len}\n")

    cost = test_batch_by_batch(all_prompts, gen_len)
    print(f"Latency of test_batch_by_batch          : {cost:.4f} s\n")

    cost = test_batch_by_batch_with_hint(all_prompts, gen_len)
    print(f"Latency of test_batch_by_batch_with_hint: {cost:.4f} s\n")

    cost = test_send_all(all_prompts, gen_len)
    print(f"Latency of test_send_all                : {cost:.4f} s\n")
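The script compares three submission strategies: one batch per prefix group, the same batches preceded by a single-request "hint" that warms the prefix cache, and everything sent as one batch. A hypothetical variation (not part of this commit) that reuses the functions above to sweep the prefix length could look like the sketch below; it assumes the same tokenizer and backend setup as in __main__.

# Hypothetical sweep (not in the commit): vary prefix_length and rerun the three
# strategies, reusing prepare_prompts / test_* defined above. Assumes tokenizer
# and backend are already initialized as in __main__.
for prefix_length in (256, 1024, 4096):
    all_prompts, tot_input_len = prepare_prompts(
        num_prefix=10,
        num_samples_per_prefix=32,
        prefix_length=prefix_length,
        suffix_length=128,
    )
    print(f"prefix_length={prefix_length}, total input tokens: {tot_input_len}")
    print(f"  batch_by_batch          : {test_batch_by_batch(all_prompts, 1):.4f} s")
    print(f"  batch_by_batch_with_hint: {test_batch_by_batch_with_hint(all_prompts, 1):.4f} s")
    print(f"  send_all                : {test_send_all(all_prompts, 1):.4f} s")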
python/sglang/srt/managers/schedule_policy.py
@@ -34,11 +34,19 @@ CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
     os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
 )
 
-# The threshold to apply in-batch prefix caching.
-# If we use too small value, in-batch prefix caching cannot be used. E.g.,
-# imagine "the" prefix.
-IN_BATCH_PREFIX_CACHING_THRESHOLD = int(
-    os.environ.get("SGLANG_IN_BATCH_PREFIX_CACHING_THRESHOLD", "32")
+# Threshold for in-batch prefix cache.
+# If a request has a matched prefix length (against existing cache) less than this value,
+# the scheduler runs the in-batch prefix caching check for this request.
+# If we set it to -1, it means we disable in-batch prefix caching.
+IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD = int(
+    os.environ.get("IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD", "32")
+)
+
+# Threshold for in-batch prefix cache.
+# If a request has a matched prefix length (within the waiting queue) larger than this value,
+# the scheduler deprioritizes this request.
+IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
+    os.environ.get("IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD", "32")
 )
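To make the split between the two constants concrete, here is a minimal sketch (not part of the diff) of how they gate the per-request decision in calc_priority; should_deprioritize, matched_against_cache, and matched_in_batch are hypothetical names used only for illustration.

# Minimal sketch (hypothetical helper, not in the diff) of how the two thresholds
# are combined for a single waiting request.
def should_deprioritize(matched_against_cache: int, matched_in_batch: int) -> bool:
    # Only requests with a short match against the existing radix cache are checked;
    # a long match means the persistent cache already covers the prefix.
    if matched_against_cache > IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
        return False
    # If another request in the waiting queue already shares a long enough prefix,
    # defer this one so the first request can populate the cache.
    return matched_in_batch >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD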
@@ -51,6 +59,11 @@ class SchedulePolicy:
         self.policy = policy
         self.tree_cache = tree_cache
 
+        # It is used to find the matching prefix for in-batch prefix caching.
+        self.waiting_queue_radix_tree = RadixCache(
+            req_to_token_pool=None, token_to_kv_pool=None, disable=False
+        )
+
     def calc_priority(self, waiting_queue: List[Req]):
         if len(waiting_queue) > 128 and self.policy == "lpm":
             # Turn off the expensive prefix matching and sorting when the #queue is large.
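The new waiting_queue_radix_tree is a scratch radix tree that lives only for one scheduling pass: it is reset at the start, each waiting request is matched against it, and requests that are not deferred are inserted, so the scheduler can detect prefixes shared within the queue itself. A toy, plain-Python illustration of the same idea (not SGLang code, hypothetical names) follows.

# Toy illustration (plain Python, not SGLang code): detect when a queued prompt
# shares a long prefix with a prompt already inspected in this scheduling pass.
def shared_prefix_len(a, b):
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

def split_queue(token_lists, threshold=32):
    kept, deferred = [], []
    for tokens in token_lists:
        if any(shared_prefix_len(tokens, other) >= threshold for other in kept):
            deferred.append(tokens)  # let an earlier request fill the cache first
        else:
            kept.append(tokens)
    return kept, deferred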
@@ -60,50 +73,54 @@ class SchedulePolicy:
         # Compute matched prefix length
         prefix_computed = False
-        # rid to deprioritize in the current run.
-        temporary_deprioritized = {}
         if policy == "lpm" or policy == "dfs-weight":
-            # It is used to find the matching prefix for in-batch prefix caching.
-            temp_radix = RadixCache(None, None, False)
+            # rid to deprioritize in the current run for in-batch prefix caching.
+            temporary_deprioritized = set()
+            self.waiting_queue_radix_tree.reset()
+
             for r in waiting_queue:
                 prefix_ids = r.adjust_max_prefix_ids()
                 # NOTE: the prefix_indices must always be aligned with last_node
                 r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
                     rid=r.rid, key=prefix_ids
                 )
 
-                # NOTE(sang): This logic is for In-batch prefix caching;
+                # NOTE(sang): This logic is for in-batch prefix caching;
                 # If there are more than 1 request that have small matching prefix from
                 # existing cache, but all those requests share the same prefix, we prefer
                 # to schedule only one of them so that we can increase the cache hit rate.
-                # We prefer to set IN_BATCH_PREFIX_CACHING_THRESHOLD > 0 because too small
+                # We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small
                 # threshold means we cannot use in-batch prefix caching for short prefixes.
-                # It is kind of common when the engine is long running (e.g., imagine "the").
-                if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_THRESHOLD:
-                    in_batch_matching_prefixes, _ = temp_radix.match_prefix(
-                        rid=r.rid, key=prefix_ids
-                    )
+                # It is kind of common when the engine is long running (e.g., imagine the prefix "the").
+                if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
+                    in_batch_matching_prefixes, _ = (
+                        self.waiting_queue_radix_tree.match_prefix(
+                            rid=r.rid, key=prefix_ids
+                        )
+                    )
                     if (
                         len(in_batch_matching_prefixes)
-                        >= IN_BATCH_PREFIX_CACHING_THRESHOLD
+                        >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
                     ):
-                        temporary_deprioritized[r.rid] = r
+                        temporary_deprioritized.add(r.rid)
                     else:
-                        temp_radix.insert(prefix_ids, torch.tensor(prefix_ids))
+                        # Insert with a dummy key
+                        self.waiting_queue_radix_tree.insert(
+                            prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
+                        )
 
             prefix_computed = True
 
         if policy == "lpm":
             # Longest Prefix Match
-            def get_priority(r: Req):
-                score = 0
-                if r.rid in temporary_deprioritized:
-                    score = float("inf")
-                else:
-                    score = -len(r.prefix_indices)
-                return score
-
-            waiting_queue.sort(key=get_priority)
+            waiting_queue.sort(
+                key=lambda r: (
+                    -len(r.prefix_indices)
+                    if r.rid not in temporary_deprioritized
+                    else float("inf")
+                )
+            )
         elif policy == "fcfs":
             # first come first serve
             pass
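The effect of temporary_deprioritized on the "lpm" ordering can be seen with a small standalone example (toy data, not from the diff): deprioritized rids get an infinite score and fall to the back of the queue, while the rest are ordered by longest matched prefix first.

# Toy example (not from the diff) of the "lpm" priority with deprioritization.
temporary_deprioritized = {"req-b"}  # its prefix is already covered by another queued request
requests = [
    {"rid": "req-a", "prefix_len": 5},
    {"rid": "req-b", "prefix_len": 900},
    {"rid": "req-c", "prefix_len": 400},
]
requests.sort(
    key=lambda r: (
        -r["prefix_len"] if r["rid"] not in temporary_deprioritized else float("inf")
    )
)
print([r["rid"] for r in requests])  # ['req-c', 'req-a', 'req-b']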
@@ -113,11 +130,11 @@ class SchedulePolicy:
         elif policy == "random":
             random.shuffle(waiting_queue)
         elif policy == "dfs-weight":
+            # Experimental policy based on custom weights
             last_node_to_reqs = defaultdict(list)
             for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)
 
+            # node -> # of requests for that node.
             node_to_weight = defaultdict(int)
             for node in last_node_to_reqs:
                 node_to_weight[node] = len(last_node_to_reqs[node])
@@ -129,9 +146,7 @@ class SchedulePolicy:
                 node_to_weight,
                 last_node_to_reqs,
                 waiting_queue,
-                temporary_deprioritized,
             )
-            waiting_queue.extend(temporary_deprioritized.values())
         else:
             raise ValueError(f"Unknown schedule_policy: {policy=}")
@@ -148,19 +163,12 @@ class SchedulePolicy:
         node_to_priority: Dict[TreeNode, int],
         last_node_to_reqs: Dict[TreeNode, List[Req]],
         q: List,
-        temporary_deprioritized: Dict[str, Req],
     ):
         childs = [child for child in cur_node.children.values()]
         childs.sort(key=lambda x: -node_to_priority[x])
         for child in childs:
-            self.get_dfs_priority(
-                child, node_to_priority, last_node_to_reqs, q, temporary_deprioritized
-            )
-
-        for req in last_node_to_reqs[cur_node]:
-            if req.rid in temporary_deprioritized:
-                continue
-            q.append(req)
+            self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
+
+        q.extend(last_node_to_reqs[cur_node])
 
 
 class AddReqResult(Enum):
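With the temporary_deprioritized filtering removed from get_dfs_priority, the dfs-weight traversal reduces to visiting children in descending weight order and extending the queue with each node's requests. A self-contained toy sketch of that traversal (plain Python, hypothetical names, not SGLang code) follows.

# Toy sketch (not SGLang code) of the simplified DFS used by "dfs-weight":
# children are visited by descending weight, then the node's requests are appended.
class Node:
    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)

def dfs_collect(node, node_to_weight, node_to_reqs, q):
    childs = sorted(node.children, key=lambda x: -node_to_weight[x.name])
    for child in childs:
        dfs_collect(child, node_to_weight, node_to_reqs, q)
    q.extend(node_to_reqs.get(node.name, []))

root = Node("root", [Node("a"), Node("b")])
q = []
dfs_collect(root, {"root": 0, "a": 3, "b": 1}, {"a": ["req1", "req2"], "b": ["req3"]}, q)
print(q)  # ['req1', 'req2', 'req3']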