Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
90a6c759
Unverified
Commit
90a6c759
authored
Nov 18, 2024
by
Ricky Xu
Committed by
GitHub
Nov 18, 2024
Browse files
[misc] partial prefix & random input generation benchmark (#9929)
Signed-off-by:
rickyx
<
rickyx@anyscale.com
>
parent
2298e69b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
91 additions
and
25 deletions
+91
-25
benchmarks/benchmark_prefix_caching.py
benchmarks/benchmark_prefix_caching.py
+91
-25
No files found.
benchmarks/benchmark_prefix_caching.py
View file @
90a6c759
...
...
@@ -54,13 +54,30 @@ def test_prefix(llm=None, sampling_params=None, prompts=None):
print
(
f
"cost time
{
end_time
-
start_time
}
"
)
def
sample_requests
(
@
dataclasses
.
dataclass
class
Request
:
prompt
:
str
prompt_len
:
int
output_len
:
int
def
sample_tokens
(
tokenizer
:
PreTrainedTokenizerBase
,
length
:
int
)
->
str
:
vocab
=
tokenizer
.
get_vocab
()
# Remove the special tokens.
vocab
=
{
k
:
v
for
k
,
v
in
vocab
.
items
()
if
k
not
in
tokenizer
.
all_special_ids
}
return
random
.
choices
(
list
(
vocab
.
values
()),
k
=
length
)
def
sample_requests_from_dataset
(
dataset_path
:
str
,
num_requests
:
int
,
tokenizer
:
PreTrainedTokenizerBase
,
input_length_range
:
Tuple
[
int
,
int
],
fixed_output_len
:
Optional
[
int
],
)
->
List
[
Tuple
[
str
,
int
,
int
]
]:
)
->
List
[
Request
]:
if
fixed_output_len
is
not
None
and
fixed_output_len
<
4
:
raise
ValueError
(
"output_len too small"
)
...
...
@@ -77,31 +94,55 @@ def sample_requests(
random
.
shuffle
(
dataset
)
min_len
,
max_len
=
input_length_range
assert
min_len
>=
0
and
max_len
>=
min_len
,
"input_length_range too small"
# Filter out sequences that are too long or too short
filtered_dataset
:
List
[
Tuple
[
str
,
int
,
int
]]
=
[]
filtered_requests
:
List
[
Request
]
=
[]
for
i
in
range
(
len
(
dataset
)):
if
len
(
filtered_
dataset
)
==
num_requests
:
if
len
(
filtered_
requests
)
==
num_requests
:
break
# Tokenize the prompts and completions.
prompt
=
dataset
[
i
][
0
]
prompt
_token_ids
=
tokenizer
(
prompt
).
input
_ids
prompt
_token_ids
=
tokenizer
(
dataset
[
i
][
0
]
).
input_ids
prompt
=
tokenizer
.
decode
(
prompt_token
_ids
)
completion
=
dataset
[
i
][
1
]
completion_token_ids
=
tokenizer
(
completion
).
input_ids
prompt_len
=
len
(
prompt_token_ids
)
output_len
=
len
(
completion_token_ids
)
if
fixed_output_len
is
None
else
fixed_output_len
if
prompt_len
<
4
or
output_len
<
4
:
# Prune too short sequences.
continue
output_len
=
(
len
(
completion_token_ids
)
if
fixed_output_len
is
None
else
fixed_output_len
)
if
min_len
<=
prompt_len
<=
max_len
:
filtered_dataset
.
append
((
prompt
,
prompt_len
,
output_len
))
filtered_requests
.
append
(
Request
(
prompt
,
prompt_len
,
output_len
))
return
filtered_requests
def
sample_requests_from_random
(
num_requests
:
int
,
tokenizer
:
PreTrainedTokenizerBase
,
input_length_range
:
Tuple
[
int
,
int
],
fixed_output_len
:
Optional
[
int
],
prefix_len
:
int
,
)
->
List
[
Request
]:
return
filtered_dataset
requests
=
[]
prefix_token_ids
=
sample_tokens
(
tokenizer
,
prefix_len
)
min_len
,
max_len
=
input_length_range
for
i
in
range
(
num_requests
):
unique_part_token_ids
=
sample_tokens
(
tokenizer
,
random
.
randint
(
min_len
-
prefix_len
,
max_len
-
prefix_len
))
prompt_token_ids
=
prefix_token_ids
+
unique_part_token_ids
prompt
=
tokenizer
.
decode
(
prompt_token_ids
)
prompt_len
=
len
(
prompt_token_ids
)
assert
(
min_len
<=
prompt_len
<=
max_len
),
f
"prompt_len
{
prompt_len
}
out of range
{
min_len
}
:
{
max_len
}
"
requests
.
append
(
Request
(
prompt
,
prompt_len
,
fixed_output_len
))
return
requests
def
repeat_and_sort_requests
(
requests
:
List
[
Tuple
[
str
,
int
,
int
]
],
def
repeat_and_sort_requests
(
requests
:
List
[
Request
],
repeat_count
:
int
,
sort
:
bool
=
False
)
->
List
[
str
]:
repeated_requests
=
requests
*
repeat_count
...
...
@@ -109,7 +150,7 @@ def repeat_and_sort_requests(requests: List[Tuple[str, int, int]],
repeated_requests
.
sort
(
key
=
lambda
x
:
x
[
1
])
else
:
random
.
shuffle
(
repeated_requests
)
return
[
req
[
0
]
for
req
in
repeated_requests
]
return
[
req
.
prompt
for
req
in
repeated_requests
]
def
main
(
args
):
...
...
@@ -117,9 +158,12 @@ def main(args):
input_length_range
=
tuple
(
map
(
int
,
args
.
input_length_range
.
split
(
':'
)))
random
.
seed
(
args
.
seed
)
if
args
.
dataset_path
is
not
None
:
print
(
f
"Start to sample
{
args
.
num_prompts
}
prompts"
if
args
.
prefix_len
>
0
:
raise
ValueError
(
"prefix-len is not supported when "
"dataset-path is provided."
)
print
(
f
"Start to sample
{
args
.
num_prompts
}
prompts "
f
"from
{
args
.
dataset_path
}
"
)
filtered_
datase
ts
=
sample_requests
(
filtered_
reques
ts
=
sample_requests
_from_dataset
(
dataset_path
=
args
.
dataset_path
,
num_requests
=
args
.
num_prompts
,
tokenizer
=
tokenizer
,
...
...
@@ -127,9 +171,22 @@ def main(args):
fixed_output_len
=
args
.
output_len
,
)
else
:
prompt_len
=
len
(
tokenizer
(
PROMPT
).
input_ids
)
filtered_datasets
=
[(
PROMPT
,
prompt_len
,
args
.
output_len
)
]
*
args
.
num_prompts
print
(
f
"Start to sample
{
args
.
num_prompts
}
prompts from random"
)
filtered_requests
=
sample_requests_from_random
(
num_requests
=
args
.
num_prompts
,
tokenizer
=
tokenizer
,
input_length_range
=
input_length_range
,
fixed_output_len
=
args
.
output_len
,
prefix_len
=
args
.
prefix_len
,
)
# Print some helpful stats of the requests.
print
(
f
"Sampled
{
len
(
filtered_requests
)
}
requests."
)
prompt_lens
=
[
req
.
prompt_len
for
req
in
filtered_requests
]
print
(
f
"Average input length:
{
sum
(
prompt_lens
)
/
len
(
prompt_lens
)
}
"
)
print
(
f
"P50 input length:
{
sorted
(
prompt_lens
)[
len
(
prompt_lens
)
//
2
]
}
"
)
print
(
f
"Min Prompt Length:
{
min
(
prompt_lens
)
}
"
)
print
(
f
"Max Prompt Length:
{
max
(
prompt_lens
)
}
"
)
engine_args
=
EngineArgs
.
from_cli_args
(
args
)
...
...
@@ -137,8 +194,8 @@ def main(args):
sampling_params
=
SamplingParams
(
temperature
=
0
,
max_tokens
=
args
.
output_len
)
print
(
"Testing filtered
datase
ts"
)
prompts
=
repeat_and_sort_requests
(
filtered_
datase
ts
,
print
(
"Testing filtered
reques
ts"
)
prompts
=
repeat_and_sort_requests
(
filtered_
reques
ts
,
repeat_count
=
args
.
repeat_count
,
sort
=
args
.
sort
)
...
...
@@ -161,20 +218,29 @@ if __name__ == "__main__":
parser
.
add_argument
(
'--output-len'
,
type
=
int
,
default
=
10
)
parser
.
add_argument
(
'--num-prompts'
,
type
=
int
,
default
=
1
,
required
=
True
,
help
=
"Number of the prompts sampled from dataset"
)
parser
.
add_argument
(
'--repeat-count'
,
type
=
int
,
default
=
1
00
,
default
=
1
,
help
=
'Number of times to repeat each prompt'
)
parser
.
add_argument
(
'--sort'
,
action
=
'store_true'
,
help
=
'Sort prompts by input length'
)
parser
.
add_argument
(
'--input-length-range'
,
type
=
str
,
default
=
'128:256'
,
required
=
True
,
help
=
'Range of input lengths for sampling prompts,'
'specified as "min:max" (e.g., "128:256").'
)
parser
.
add_argument
(
"--prefix-len"
,
type
=
int
,
default
=
0
,
help
=
"Specifies the length of a common prefix to be "
"added to the input prompt. The input-length-range will "
"subtract this length when filtering prompts. Only used "
"when dataset-path is not provided."
,
)
parser
=
EngineArgs
.
add_cli_args
(
parser
)
args
=
parser
.
parse_args
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment