sglang · Commits

Commit 13387e6b (unverified), authored Jan 17, 2025 by Zhiqiang Xie, committed by GitHub on Jan 17, 2025
Parent: 120c3634

Multi-turn benchmark for hierarchical caching (#2942)

1 changed file, +334 additions, -0 deletions
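How to run it: the script assumes an SGLang server is already serving the model. The invocation below is a plausible sketch rather than part of the commit; in particular, the --enable-hierarchical-cache server flag is an assumption about how hierarchical caching is switched on, and the model path simply mirrors the script's default.

# Start an SGLang server (the hierarchical-cache flag is an assumption):
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --port 30000 --enable-hierarchical-cache

# Point the benchmark at it:
python benchmark/hicache/bench_multiturn.py --num-clients 200 --num-rounds 5 --distribution poisson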
benchmark/hicache/bench_multiturn.py (new file, mode 100644)
import argparse
import asyncio
import json
import queue
import random
import threading
import time
from typing import Optional

import aiohttp
import requests
from tqdm.asyncio import tqdm

from sglang.bench_serving import (
    RequestFuncOutput,
    get_tokenizer,
    remove_prefix,
    sample_random_requests,
)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Script to benchmark concurrent requests to a server."
    )
    parser.add_argument(
        "--num-clients",
        type=int,
        default=200,
        help="Number of concurrent clients",
    )
    parser.add_argument(
        "--request-length",
        type=int,
        default=512,
        help="Length of each new request",
    )
    parser.add_argument(
        "--output-length",
        type=int,
        default=64,
        help="Length of each output",
    )
    parser.add_argument(
        "--num-rounds",
        type=int,
        default=5,
        help="Number of rounds per client",
    )
    parser.add_argument(
        "--distribution",
        type=str,
        default="poisson",
        choices=["poisson", "uniform"],
        help="Distribution type for request intervals (poisson or uniform)",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=1.0,
        help="Average number of requests per second",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help="Server hostname or IP (default: localhost)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=30000,
        help="Server port (default: 30000)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="meta-llama/Llama-3.1-8B-Instruct",
        help="model path compatible with Hugging Face Transformers",
    )
    return parser.parse_args()


async def async_request_sglang_generate(
    payload,
    url,
    pbar: Optional[tqdm] = None,
):
    """
    Sends a streaming request to the server. Gathers text token-by-token.
    """
    async with aiohttp.ClientSession() as session:
        headers = {}
        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        output = RequestFuncOutput()

        try:
            async with session.post(url=url, json=payload, headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
                            pass
                        else:
                            data = json.loads(chunk)

                            if data["text"]:
                                timestamp = time.perf_counter()
                                # First token
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft
                                # Decoding phase
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text = data["text"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception as e:
            output.success = False
            output.error = str(e)
            print(f"Request failed: {e}")

    if pbar:
        pbar.update(1)
    return output


def gen_payload(prompt, output_len):
    payload = {
        "text": prompt,
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": output_len,
            "ignore_eos": True,
        },
        "stream": True,
        "lora_path": "",
        "return_logprob": False,
        "logprob_start_len": -1,
    }
    return payload


class ReadyQueue:
    """
    Thread-safe queue that can pop requests in different orders based on given policy.
    """

    def __init__(self, init_requests=None, policy="random"):
        self.lock = threading.Lock()
        self.requests = init_requests or []
        self.policy = policy

    def append(self, item):
        with self.lock:
            self.requests.append(item)

    def pop(self):
        with self.lock:
            if not self.requests:
                return None
            if self.policy == "random":
                index = random.randrange(len(self.requests))
                return self.requests.pop(index)
            elif self.policy == "fifo":
                return self.requests.pop(0)
            else:
                # todo, varying thinking time of clients
                raise ValueError(f"{self.policy} not implemented")


class WorkloadGenerator:
    def __init__(self, args):
        # Construct the base URL for requests
        self.url = f"http://{args.host}:{args.port}/generate"
        self.tokenizer = get_tokenizer(args.model)

        self.distribution = args.distribution
        self.request_rate = args.request_rate
        self.start_time = None
        self.finished_time = None

        self.candidate_inputs = sample_random_requests(
            input_len=args.request_length,
            output_len=args.output_length,
            num_prompts=args.num_clients * args.num_rounds,
            range_ratio=1.0,
            tokenizer=self.tokenizer,
            dataset_path="",
        )
        self.candidate_inputs = [i[0] for i in self.candidate_inputs]

        init_requests = [
            (i, gen_payload(self.candidate_inputs[i], args.output_length))
            for i in range(args.num_clients)
        ]
        self.client_records = {
            i: {"round": 0, "history": init_requests[i][1]["text"]}
            for i in range(args.num_clients)
        }
        self.ready_queue = ReadyQueue(init_requests=init_requests)
        self.candidate_inputs = self.candidate_inputs[args.num_clients :]

        self.response_queue = queue.Queue()
        self.pbar = tqdm(total=args.num_clients * args.num_rounds)
        self.performance_metrics = {"ttft": [], "latency": []}

    async def handle_request(self, item):
        try:
            client_id, payload = item
            response = await async_request_sglang_generate(payload, self.url, self.pbar)
            if self.pbar.n == self.pbar.total:
                self.finished_time = time.time()
            self.response_queue.put((client_id, response))
        except Exception as e:
            print(f"Request failed: {e}")

    def request_sender(self):
        async def request_loop():
            while True:
                # Calculate Poisson-distributed wait time
                if self.distribution == "poisson":
                    sleep_time = random.expovariate(self.request_rate)
                elif self.distribution == "uniform":
                    avg_interval = (
                        1.0 / self.request_rate if self.request_rate > 0 else 1.0
                    )
                    sleep_time = random.uniform(0, 2 * avg_interval)
                else:
                    raise ValueError("Invalid distribution type")
                await asyncio.sleep(sleep_time)  # Wait before sending the next request

                new_request = self.ready_queue.pop()
                # Submit async request
                if new_request:
                    asyncio.create_task(self.handle_request(new_request))
                else:
                    if self.pbar.n == self.pbar.total:
                        break

        # Create and run the event loop for asynchronous requests
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(request_loop())
        loop.close()

    def response_handler(self):
        while True:
            try:
                client_id, response = self.response_queue.get(
                    timeout=10
                )  # Block until response is available
                if not response.success:
                    raise ValueError(f"Request failed with error: {response.error}")

                self.client_records[client_id]["history"] += response.generated_text
                self.client_records[client_id]["round"] += 1
                self.performance_metrics["ttft"].append(response.ttft)
                self.performance_metrics["latency"].append(response.latency)

                if self.client_records[client_id]["round"] < args.num_rounds:
                    self.client_records[client_id]["history"] += self.candidate_inputs.pop()
                    self.ready_queue.append(
                        (
                            client_id,
                            gen_payload(
                                self.client_records[client_id]["history"],
                                args.output_length,
                            ),
                        )
                    )
            except queue.Empty:
                if self.pbar.n == self.pbar.total:
                    break

    def run(self):
        request_thread = threading.Thread(target=self.request_sender, daemon=True)
        response_thread = threading.Thread(target=self.response_handler, daemon=True)

        self.start_time = time.time()
        request_thread.start()
        response_thread.start()

        request_thread.join()
        response_thread.join()

        self.pbar.close()
        print("All requests completed.")
        print("Performance metrics summary:")
        print(
            f"  Total requests: {len(self.performance_metrics['ttft'])} at {self.request_rate} requests per second"
        )
        print(
            f"  Average TTFT: {sum(self.performance_metrics['ttft']) / len(self.performance_metrics['ttft']):.2f}"
        )
        print(
            f"  Median TTFT: {sorted(self.performance_metrics['ttft'])[len(self.performance_metrics['ttft']) // 2]:.2f}"
        )
        print(
            f"  Average latency: {sum(self.performance_metrics['latency']) / len(self.performance_metrics['latency']):.2f}"
        )
        print(
            f"  Median latency: {sorted(self.performance_metrics['latency'])[len(self.performance_metrics['latency']) // 2]:.2f}"
        )

        throughput = self.pbar.total / (self.finished_time - self.start_time)
        print(f"Throughput: {throughput:.2f} requests per second")


if __name__ == "__main__":
    args = parse_args()
    flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"

    for request_rate in range(1, 41, 2):
        args.request_rate = request_rate
        requests.post(flush_cache_url)
        WorkloadGenerator(args).run()
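A back-of-envelope check on the workload shape (my own arithmetic, not part of the commit): each round appends roughly request-length new input tokens plus output-length generated tokens to a client's history, so under the defaults round r resubmits a prompt of about r*512 + (r-1)*64 tokens, almost all of it a prefix the server has already processed, which is exactly the reuse pattern hierarchical caching targets.

# Hypothetical sanity check using the script's defaults (512 in, 64 out, 5 rounds).
request_length, output_length, num_rounds = 512, 64, 5
for r in range(1, num_rounds + 1):
    prompt_len = r * request_length + (r - 1) * output_length
    print(f"round {r}: ~{prompt_len} prompt tokens")
# Round 5 resubmits ~2816 tokens; the first ~2304 were already processed in round 4.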
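The entry point sweeps request rates 1, 3, ..., 39 and flushes the server cache before each run so every rate starts cold. For a single measurement at the rate given on the command line, a minimal hypothetical variant of the main block would be:

if __name__ == "__main__":
    args = parse_args()
    # Start from a cold cache, then measure one point at args.request_rate.
    requests.post(f"http://{args.host}:{args.port}/flush_cache")
    WorkloadGenerator(args).run()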