Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a904ea78
Unverified
Commit
a904ea78
authored
Sep 17, 2025
by
Simon Mo
Committed by
GitHub
Sep 17, 2025
Browse files
[benchmark] add peak throughput metrics and plot (#23867)
Signed-off-by:
simon-mo
<
simon.mo@hey.com
>
parent
b7433ca1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
134 additions
and
69 deletions
+134
-69
vllm/benchmarks/lib/endpoint_request_func.py
vllm/benchmarks/lib/endpoint_request_func.py
+5
-0
vllm/benchmarks/serve.py
vllm/benchmarks/serve.py
+129
-69
No files found.
vllm/benchmarks/lib/endpoint_request_func.py
View file @
a904ea78
...
@@ -89,6 +89,7 @@ class RequestFuncOutput:
...
@@ -89,6 +89,7 @@ class RequestFuncOutput:
tpot
:
float
=
0.0
# avg next-token latencies
tpot
:
float
=
0.0
# avg next-token latencies
prompt_len
:
int
=
0
prompt_len
:
int
=
0
error
:
str
=
""
error
:
str
=
""
start_time
:
float
=
0.0
async
def
async_request_openai_completions
(
async
def
async_request_openai_completions
(
...
@@ -140,6 +141,7 @@ async def async_request_openai_completions(
...
@@ -140,6 +141,7 @@ async def async_request_openai_completions(
generated_text
=
""
generated_text
=
""
st
=
time
.
perf_counter
()
st
=
time
.
perf_counter
()
output
.
start_time
=
st
most_recent_timestamp
=
st
most_recent_timestamp
=
st
try
:
try
:
async
with
session
.
post
(
url
=
api_url
,
json
=
payload
,
async
with
session
.
post
(
url
=
api_url
,
json
=
payload
,
...
@@ -272,6 +274,7 @@ async def async_request_openai_chat_completions(
...
@@ -272,6 +274,7 @@ async def async_request_openai_chat_completions(
generated_text
=
""
generated_text
=
""
ttft
=
0.0
ttft
=
0.0
st
=
time
.
perf_counter
()
st
=
time
.
perf_counter
()
output
.
start_time
=
st
most_recent_timestamp
=
st
most_recent_timestamp
=
st
try
:
try
:
async
with
session
.
post
(
url
=
api_url
,
json
=
payload
,
async
with
session
.
post
(
url
=
api_url
,
json
=
payload
,
...
@@ -396,6 +399,7 @@ async def async_request_openai_audio(
...
@@ -396,6 +399,7 @@ async def async_request_openai_audio(
generated_text
=
""
generated_text
=
""
ttft
=
0.0
ttft
=
0.0
st
=
time
.
perf_counter
()
st
=
time
.
perf_counter
()
output
.
start_time
=
st
most_recent_timestamp
=
st
most_recent_timestamp
=
st
try
:
try
:
async
with
session
.
post
(
url
=
api_url
,
async
with
session
.
post
(
url
=
api_url
,
...
@@ -475,6 +479,7 @@ async def async_request_openai_embeddings(
...
@@ -475,6 +479,7 @@ async def async_request_openai_embeddings(
output
=
RequestFuncOutput
()
output
=
RequestFuncOutput
()
st
=
time
.
perf_counter
()
st
=
time
.
perf_counter
()
output
.
start_time
=
st
try
:
try
:
async
with
session
.
post
(
async
with
session
.
post
(
url
=
api_url
,
url
=
api_url
,
...
...
vllm/benchmarks/serve.py
View file @
a904ea78
...
@@ -18,9 +18,11 @@ On the client side, run:
...
@@ -18,9 +18,11 @@ On the client side, run:
import
argparse
import
argparse
import
asyncio
import
asyncio
import
gc
import
gc
import
importlib.util
import
json
import
json
import
os
import
os
import
random
import
random
import
shutil
import
time
import
time
import
warnings
import
warnings
from
collections.abc
import
AsyncGenerator
,
Iterable
from
collections.abc
import
AsyncGenerator
,
Iterable
...
@@ -46,6 +48,9 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
...
@@ -46,6 +48,9 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
MILLISECONDS_TO_SECONDS_CONVERSION
=
1000
MILLISECONDS_TO_SECONDS_CONVERSION
=
1000
TERM_PLOTLIB_AVAILABLE
=
((
importlib
.
util
.
find_spec
(
"termplotlib"
)
is
not
None
)
and
(
shutil
.
which
(
"gnuplot"
)
is
not
None
))
class
TaskType
(
Enum
):
class
TaskType
(
Enum
):
GENERATION
=
"generation"
GENERATION
=
"generation"
...
@@ -80,18 +85,23 @@ class BenchmarkMetrics:
...
@@ -80,18 +85,23 @@ class BenchmarkMetrics:
median_e2el_ms
:
float
median_e2el_ms
:
float
std_e2el_ms
:
float
std_e2el_ms
:
float
percentiles_e2el_ms
:
list
[
tuple
[
float
,
float
]]
percentiles_e2el_ms
:
list
[
tuple
[
float
,
float
]]
# Max output tokens per second and concurrent requests at that peak
max_output_tokens_per_s
:
float
max_concurrent_requests
:
int
@
dataclass
@
dataclass
class
EmbedBenchmarkMetrics
:
class
EmbedBenchmarkMetrics
:
completed
:
int
completed
:
int
total_input
:
int
total_input
:
int
request_throughput
:
float
request_throughput
:
float
total_token_throughput
:
float
total_token_throughput
:
float
mean_e2el_ms
:
float
mean_e2el_ms
:
float
std_e2el_ms
:
float
std_e2el_ms
:
float
median_e2el_ms
:
float
median_e2el_ms
:
float
percentiles_e2el_ms
:
float
percentiles_e2el_ms
:
float
def
_get_current_request_rate
(
def
_get_current_request_rate
(
ramp_up_strategy
:
Optional
[
Literal
[
"linear"
,
"exponential"
]],
ramp_up_strategy
:
Optional
[
Literal
[
"linear"
,
"exponential"
]],
ramp_up_start_rps
:
Optional
[
int
],
ramp_up_start_rps
:
Optional
[
int
],
...
@@ -150,8 +160,8 @@ async def get_request(
...
@@ -150,8 +160,8 @@ async def get_request(
assert
burstiness
>
0
,
(
assert
burstiness
>
0
,
(
f
"A positive burstiness factor is expected, but given
{
burstiness
}
."
)
f
"A positive burstiness factor is expected, but given
{
burstiness
}
."
)
# Convert to list to get length for ramp-up calculations
# Convert to list to get length for ramp-up calculations
if
isinstance
(
input_requests
,
Iterable
)
and
not
isinstance
(
if
isinstance
(
input_requests
,
input_requests
,
list
):
Iterable
)
and
not
isinstance
(
input_requests
,
list
):
input_requests
=
list
(
input_requests
)
input_requests
=
list
(
input_requests
)
total_requests
=
len
(
input_requests
)
total_requests
=
len
(
input_requests
)
...
@@ -161,12 +171,9 @@ async def get_request(
...
@@ -161,12 +171,9 @@ async def get_request(
request_rates
=
[]
request_rates
=
[]
delay_ts
=
[]
delay_ts
=
[]
for
request_index
,
request
in
enumerate
(
input_requests
):
for
request_index
,
request
in
enumerate
(
input_requests
):
current_request_rate
=
_get_current_request_rate
(
ramp_up_strategy
,
current_request_rate
=
_get_current_request_rate
(
ramp_up_start_rps
,
ramp_up_strategy
,
ramp_up_start_rps
,
ramp_up_end_rps
,
ramp_up_end_rps
,
request_index
,
total_requests
,
request_rate
)
request_index
,
total_requests
,
request_rate
)
request_rates
.
append
(
current_request_rate
)
request_rates
.
append
(
current_request_rate
)
if
current_request_rate
==
float
(
"inf"
):
if
current_request_rate
==
float
(
"inf"
):
delay_ts
.
append
(
0
)
delay_ts
.
append
(
0
)
...
@@ -206,10 +213,8 @@ async def get_request(
...
@@ -206,10 +213,8 @@ async def get_request(
def
calculate_metrics_for_embeddings
(
def
calculate_metrics_for_embeddings
(
outputs
:
list
[
RequestFuncOutput
],
outputs
:
list
[
RequestFuncOutput
],
dur_s
:
float
,
dur_s
:
float
,
selected_percentiles
:
list
[
float
])
->
EmbedBenchmarkMetrics
:
selected_percentiles
:
list
[
float
]
)
->
EmbedBenchmarkMetrics
:
"""Calculate the metrics for the embedding requests.
"""Calculate the metrics for the embedding requests.
Args:
Args:
...
@@ -242,10 +247,8 @@ def calculate_metrics_for_embeddings(
...
@@ -242,10 +247,8 @@ def calculate_metrics_for_embeddings(
mean_e2el_ms
=
np
.
mean
(
e2els
or
0
)
*
1000
,
mean_e2el_ms
=
np
.
mean
(
e2els
or
0
)
*
1000
,
std_e2el_ms
=
np
.
std
(
e2els
or
0
)
*
1000
,
std_e2el_ms
=
np
.
std
(
e2els
or
0
)
*
1000
,
median_e2el_ms
=
np
.
median
(
e2els
or
0
)
*
1000
,
median_e2el_ms
=
np
.
median
(
e2els
or
0
)
*
1000
,
percentiles_e2el_ms
=
[
percentiles_e2el_ms
=
[(
p
,
np
.
percentile
(
e2els
or
0
,
p
)
*
1000
)
(
p
,
np
.
percentile
(
e2els
or
0
,
p
)
*
1000
)
for
p
in
selected_percentiles
],
for
p
in
selected_percentiles
],
)
)
return
metrics
return
metrics
...
@@ -336,6 +339,67 @@ def calculate_metrics(
...
@@ -336,6 +339,67 @@ def calculate_metrics(
"All requests failed. This is likely due to a misconfiguration "
"All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments."
,
"on the benchmark arguments."
,
stacklevel
=
2
)
stacklevel
=
2
)
# Calculate max output tokens per second metric
max_output_tokens_per_s
=
0.0
max_concurrent_requests
=
0
# Find the time range across all successful requests
successful_outputs
=
[
output
for
output
in
outputs
if
output
.
success
]
if
successful_outputs
:
min_start_time
=
min
(
output
.
start_time
for
output
in
successful_outputs
)
max_end_time
=
max
(
output
.
start_time
+
output
.
latency
for
output
in
successful_outputs
)
# Create second buckets (ceiling to ensure we capture all time)
duration_seconds
=
int
(
np
.
ceil
(
max_end_time
-
min_start_time
))
+
1
tokens_per_second
=
np
.
zeros
(
duration_seconds
)
concurrent_requests_per_second
=
np
.
zeros
(
duration_seconds
)
for
i
,
output
in
enumerate
(
successful_outputs
):
# Calculate token generation timestamp using
# start_time, ttft, and itl
token_times
=
[
output
.
start_time
+
output
.
ttft
]
current_time
=
token_times
[
0
]
for
itl_value
in
output
.
itl
:
current_time
+=
itl_value
token_times
.
append
(
current_time
)
# Add tokens to second buckets
for
token_time
in
token_times
:
second_bucket
=
int
(
token_time
-
min_start_time
)
if
0
<=
second_bucket
<
duration_seconds
:
tokens_per_second
[
second_bucket
]
+=
1
# Track concurrent requests for each second this request was active
request_start_second
=
int
(
output
.
start_time
-
min_start_time
)
request_end_second
=
int
((
output
.
start_time
+
output
.
latency
)
-
min_start_time
)
for
second
in
range
(
request_start_second
,
request_end_second
+
1
):
concurrent_requests_per_second
[
second
]
+=
1
# Find the maximum tokens per second and corresponding
# concurrent requests
if
len
(
tokens_per_second
)
>
0
:
max_output_tokens_per_s
=
float
(
np
.
max
(
tokens_per_second
))
max_concurrent_requests
=
int
(
np
.
max
(
concurrent_requests_per_second
))
if
TERM_PLOTLIB_AVAILABLE
:
import
termplotlib
as
tpl
fig
=
tpl
.
figure
()
fig
.
plot
(
np
.
arange
(
len
(
tokens_per_second
)),
tokens_per_second
,
title
=
"Output tokens per second"
)
fig
.
plot
(
np
.
arange
(
len
(
concurrent_requests_per_second
)),
concurrent_requests_per_second
,
title
=
"Concurrent requests per second"
)
fig
.
show
()
else
:
print
(
"tip: install termplotlib and gnuplot to plot the metrics"
)
metrics
=
BenchmarkMetrics
(
metrics
=
BenchmarkMetrics
(
completed
=
completed
,
completed
=
completed
,
total_input
=
total_input
,
total_input
=
total_input
,
...
@@ -365,6 +429,8 @@ def calculate_metrics(
...
@@ -365,6 +429,8 @@ def calculate_metrics(
median_e2el_ms
=
np
.
median
(
e2els
or
0
)
*
1000
,
median_e2el_ms
=
np
.
median
(
e2els
or
0
)
*
1000
,
percentiles_e2el_ms
=
[(
p
,
np
.
percentile
(
e2els
or
0
,
p
)
*
1000
)
percentiles_e2el_ms
=
[(
p
,
np
.
percentile
(
e2els
or
0
,
p
)
*
1000
)
for
p
in
selected_percentiles
],
for
p
in
selected_percentiles
],
max_output_tokens_per_s
=
max_output_tokens_per_s
,
max_concurrent_requests
=
max_concurrent_requests
,
)
)
return
metrics
,
actual_output_lens
return
metrics
,
actual_output_lens
...
@@ -396,11 +462,8 @@ async def benchmark(
...
@@ -396,11 +462,8 @@ async def benchmark(
ramp_up_end_rps
:
Optional
[
int
]
=
None
,
ramp_up_end_rps
:
Optional
[
int
]
=
None
,
ready_check_timeout_sec
:
int
=
600
,
ready_check_timeout_sec
:
int
=
600
,
):
):
task_type
=
(
task_type
=
(
TaskType
.
EMBEDDING
if
api_url
.
endswith
(
"/v1/embeddings"
)
else
TaskType
.
EMBEDDING
TaskType
.
GENERATION
)
if
api_url
.
endswith
(
"/v1/embeddings"
)
else
TaskType
.
GENERATION
)
if
endpoint_type
in
ASYNC_REQUEST_FUNCS
:
if
endpoint_type
in
ASYNC_REQUEST_FUNCS
:
if
task_type
==
TaskType
.
EMBEDDING
:
if
task_type
==
TaskType
.
EMBEDDING
:
request_func
=
ASYNC_REQUEST_FUNCS
[
"openai-embeddings"
]
request_func
=
ASYNC_REQUEST_FUNCS
[
"openai-embeddings"
]
...
@@ -435,14 +498,10 @@ async def benchmark(
...
@@ -435,14 +498,10 @@ async def benchmark(
input_requests
[
0
].
multi_modal_data
,
input_requests
[
0
].
multi_modal_data
,
)
)
assert
(
assert
(
test_mm_content
is
None
or
isinstance
(
test_mm_content
,
dict
)
test_mm_content
is
None
or
(
isinstance
(
test_mm_content
,
list
)
or
isinstance
(
test_mm_content
,
dict
)
and
all
(
isinstance
(
item
,
dict
)
for
item
in
test_mm_content
))
or
(
),
"multi_modal_data must be a dict or list[dict]"
isinstance
(
test_mm_content
,
list
)
and
all
(
isinstance
(
item
,
dict
)
for
item
in
test_mm_content
)
)
),
"multi_modal_data must be a dict or list[dict]"
test_input
=
RequestFuncInput
(
test_input
=
RequestFuncInput
(
model
=
model_id
,
model
=
model_id
,
model_name
=
model_name
,
model_name
=
model_name
,
...
@@ -488,13 +547,13 @@ async def benchmark(
...
@@ -488,13 +547,13 @@ async def benchmark(
ignore_eos
=
ignore_eos
,
ignore_eos
=
ignore_eos
,
extra_headers
=
extra_headers
,
extra_headers
=
extra_headers
,
extra_body
=
extra_body
)
extra_body
=
extra_body
)
profile_output
=
await
request_func
(
profile_output
=
await
request_func
(
request_func_input
=
profile_input
,
request_func_input
=
profile_input
,
session
=
session
)
session
=
session
)
if
profile_output
.
success
:
if
profile_output
.
success
:
print
(
"Profiler started"
)
print
(
"Profiler started"
)
distribution
=
(
"Poisson process"
if
burstiness
==
1.0
distribution
=
(
"Poisson process"
else
"Gamma distribution"
)
if
burstiness
==
1.0
else
"Gamma distribution"
)
if
ramp_up_strategy
is
not
None
:
if
ramp_up_strategy
is
not
None
:
print
(
f
"Traffic ramp-up strategy:
{
ramp_up_strategy
}
."
)
print
(
f
"Traffic ramp-up strategy:
{
ramp_up_strategy
}
."
)
...
@@ -562,18 +621,20 @@ async def benchmark(
...
@@ -562,18 +621,20 @@ async def benchmark(
req_lora_module
=
next
(
lora_modules
)
req_lora_module
=
next
(
lora_modules
)
req_model_id
,
req_model_name
=
req_lora_module
,
req_lora_module
req_model_id
,
req_model_name
=
req_lora_module
,
req_lora_module
request_func_input
=
RequestFuncInput
(
model
=
req_model_id
,
request_func_input
=
RequestFuncInput
(
model_name
=
req_model_name
,
model
=
req_model_id
,
prompt
=
prompt
,
model_name
=
req_model_name
,
api_url
=
api_url
,
prompt
=
prompt
,
prompt_len
=
prompt_len
,
api_url
=
api_url
,
output_len
=
output_len
,
prompt_len
=
prompt_len
,
logprobs
=
logprobs
,
output_len
=
output_len
,
multi_modal_content
=
mm_content
,
logprobs
=
logprobs
,
ignore_eos
=
ignore_eos
,
multi_modal_content
=
mm_content
,
extra_headers
=
extra_headers
,
ignore_eos
=
ignore_eos
,
extra_body
=
extra_body
,
extra_headers
=
extra_headers
,
request_id
=
request_id
,)
extra_body
=
extra_body
,
request_id
=
request_id
,
)
tasks
.
append
(
tasks
.
append
(
asyncio
.
create_task
(
asyncio
.
create_task
(
limited_request_func
(
request_func_input
=
request_func_input
,
limited_request_func
(
request_func_input
=
request_func_input
,
...
@@ -615,19 +676,21 @@ async def benchmark(
...
@@ -615,19 +676,21 @@ async def benchmark(
benchmark_duration
))
benchmark_duration
))
print
(
"{:<40} {:<10}"
.
format
(
"Total input tokens:"
,
metrics
.
total_input
))
print
(
"{:<40} {:<10}"
.
format
(
"Total input tokens:"
,
metrics
.
total_input
))
if
isinstance
(
metrics
,
BenchmarkMetrics
):
if
isinstance
(
metrics
,
BenchmarkMetrics
):
print
(
"{:<40} {:<10}"
.
format
(
print
(
"{:<40} {:<10}"
.
format
(
"Total generated tokens:"
,
"Total generated tokens:"
,
metrics
.
total_output
))
metrics
.
total_output
))
print
(
"{:<40} {:<10.2f}"
.
format
(
"Request throughput (req/s):"
,
print
(
"{:<40} {:<10.2f}"
.
format
(
"Request throughput (req/s):"
,
metrics
.
request_throughput
))
metrics
.
request_throughput
))
if
goodput_config_dict
:
if
goodput_config_dict
:
print
(
"{:<40} {:<10.2f}"
.
format
(
"Request goodput (req/s):"
,
print
(
"{:<40} {:<10.2f}"
.
format
(
"Request goodput (req/s):"
,
metrics
.
request_goodput
))
metrics
.
request_goodput
))
if
isinstance
(
metrics
,
BenchmarkMetrics
):
if
isinstance
(
metrics
,
BenchmarkMetrics
):
print
(
print
(
"{:<40} {:<10.2f}"
.
format
(
"Output token throughput (tok/s):"
,
"{:<40} {:<10.2f}"
.
format
(
metrics
.
output_throughput
))
"Output token throughput (tok/s):"
,
metrics
.
output_throughput
print
(
"{:<40} {:<10.2f}"
.
format
(
)
"Peak output token throughput (tok/s):"
,
)
metrics
.
max_output_tokens_per_s
))
print
(
"{:<40} {:<10.2f}"
.
format
(
"Peak concurrent requests:"
,
metrics
.
max_concurrent_requests
))
print
(
"{:<40} {:<10.2f}"
.
format
(
"Total Token throughput (tok/s):"
,
print
(
"{:<40} {:<10.2f}"
.
format
(
"Total Token throughput (tok/s):"
,
metrics
.
total_token_throughput
))
metrics
.
total_token_throughput
))
...
@@ -648,6 +711,8 @@ async def benchmark(
...
@@ -648,6 +711,8 @@ async def benchmark(
"itls"
:
[
output
.
itl
for
output
in
outputs
],
"itls"
:
[
output
.
itl
for
output
in
outputs
],
"generated_texts"
:
[
output
.
generated_text
for
output
in
outputs
],
"generated_texts"
:
[
output
.
generated_text
for
output
in
outputs
],
"errors"
:
[
output
.
error
for
output
in
outputs
],
"errors"
:
[
output
.
error
for
output
in
outputs
],
"max_output_tokens_per_s"
:
metrics
.
max_output_tokens_per_s
,
"max_concurrent_requests"
:
metrics
.
max_concurrent_requests
,
}
}
else
:
else
:
result
=
{
result
=
{
...
@@ -697,8 +762,8 @@ async def benchmark(
...
@@ -697,8 +762,8 @@ async def benchmark(
if
task_type
==
TaskType
.
GENERATION
:
if
task_type
==
TaskType
.
GENERATION
:
process_one_metric
(
"ttft"
,
"TTFT"
,
"Time to First Token"
)
process_one_metric
(
"ttft"
,
"TTFT"
,
"Time to First Token"
)
process_one_metric
(
process_one_metric
(
"tpot"
,
"TPOT"
,
"tpot"
,
"TPOT"
,
"Time per Output Token (excl. 1st token)"
)
"Time per Output Token (excl. 1st token)"
)
process_one_metric
(
"itl"
,
"ITL"
,
"Inter-token Latency"
)
process_one_metric
(
"itl"
,
"ITL"
,
"Inter-token Latency"
)
process_one_metric
(
"e2el"
,
"E2EL"
,
"End-to-end Latency"
)
process_one_metric
(
"e2el"
,
"E2EL"
,
"End-to-end Latency"
)
...
@@ -714,8 +779,8 @@ async def benchmark(
...
@@ -714,8 +779,8 @@ async def benchmark(
output_len
=
test_output_len
,
output_len
=
test_output_len
,
logprobs
=
logprobs
,
logprobs
=
logprobs
,
)
)
profile_output
=
await
request_func
(
profile_output
=
await
request_func
(
request_func_input
=
profile_input
,
request_func_input
=
profile_input
,
session
=
session
)
session
=
session
)
if
profile_output
.
success
:
if
profile_output
.
success
:
print
(
"Profiler stopped"
)
print
(
"Profiler stopped"
)
...
@@ -851,7 +916,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
...
@@ -851,7 +916,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser
.
add_argument
(
parser
.
add_argument
(
"--tokenizer"
,
"--tokenizer"
,
type
=
str
,
type
=
str
,
help
=
"Name or path of the tokenizer, if not using the default tokenizer."
,
# noqa: E501
help
=
"Name or path of the tokenizer, if not using the default tokenizer."
,
# noqa: E501
)
)
parser
.
add_argument
(
"--use-beam-search"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--use-beam-search"
,
action
=
"store_true"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -982,7 +1048,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
...
@@ -982,7 +1048,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
help
=
"Specify the prefix of request id."
,
help
=
"Specify the prefix of request id."
,
)
)
sampling_group
=
parser
.
add_argument_group
(
"sampling parameters"
)
sampling_group
=
parser
.
add_argument_group
(
"sampling parameters"
)
sampling_group
.
add_argument
(
sampling_group
.
add_argument
(
"--top-p"
,
"--top-p"
,
...
@@ -1047,8 +1112,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
...
@@ -1047,8 +1112,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
help
=
"The ramp-up strategy. This would be used to "
help
=
"The ramp-up strategy. This would be used to "
"ramp up the request rate from initial RPS to final "
"ramp up the request rate from initial RPS to final "
"RPS rate (specified by --ramp-up-start-rps and "
"RPS rate (specified by --ramp-up-start-rps and "
"--ramp-up-end-rps.) over the duration of the benchmark."
"--ramp-up-end-rps.) over the duration of the benchmark."
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--ramp-up-start-rps"
,
"--ramp-up-start-rps"
,
type
=
int
,
type
=
int
,
...
@@ -1087,13 +1151,11 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
...
@@ -1087,13 +1151,11 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
raise
ValueError
(
raise
ValueError
(
"When using ramp-up, do not specify --request-rate. "
"When using ramp-up, do not specify --request-rate. "
"The request rate will be controlled by ramp-up parameters. "
"The request rate will be controlled by ramp-up parameters. "
"Please remove the --request-rate argument."
"Please remove the --request-rate argument."
)
)
if
args
.
ramp_up_start_rps
is
None
or
args
.
ramp_up_end_rps
is
None
:
if
args
.
ramp_up_start_rps
is
None
or
args
.
ramp_up_end_rps
is
None
:
raise
ValueError
(
raise
ValueError
(
"When using --ramp-up-strategy, both --ramp-up-start-rps and "
"When using --ramp-up-strategy, both --ramp-up-start-rps and "
"--ramp-up-end-rps must be specified"
"--ramp-up-end-rps must be specified"
)
)
if
args
.
ramp_up_start_rps
<
0
or
args
.
ramp_up_end_rps
<
0
:
if
args
.
ramp_up_start_rps
<
0
or
args
.
ramp_up_end_rps
<
0
:
raise
ValueError
(
"Ramp-up start and end RPS must be non-negative"
)
raise
ValueError
(
"Ramp-up start and end RPS must be non-negative"
)
if
args
.
ramp_up_start_rps
>
args
.
ramp_up_end_rps
:
if
args
.
ramp_up_start_rps
>
args
.
ramp_up_end_rps
:
...
@@ -1127,8 +1189,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
...
@@ -1127,8 +1189,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
headers
[
kvstring
[
0
].
strip
()]
=
kvstring
[
1
].
strip
()
headers
[
kvstring
[
0
].
strip
()]
=
kvstring
[
1
].
strip
()
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"Invalid header format. Please use KEY=VALUE format."
"Invalid header format. Please use KEY=VALUE format."
)
)
tokenizer
=
get_tokenizer
(
tokenizer_id
,
tokenizer
=
get_tokenizer
(
tokenizer_id
,
tokenizer_mode
=
tokenizer_mode
,
tokenizer_mode
=
tokenizer_mode
,
...
@@ -1215,8 +1276,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
...
@@ -1215,8 +1276,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
result_json
[
kvstring
[
0
].
strip
()]
=
kvstring
[
1
].
strip
()
result_json
[
kvstring
[
0
].
strip
()]
=
kvstring
[
1
].
strip
()
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"Invalid metadata format. Please use KEY=VALUE format."
"Invalid metadata format. Please use KEY=VALUE format."
)
)
# Traffic
# Traffic
result_json
[
"request_rate"
]
=
(
args
.
request_rate
if
args
.
request_rate
result_json
[
"request_rate"
]
=
(
args
.
request_rate
if
args
.
request_rate
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment