Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6f547829
Unverified
Commit
6f547829
authored
Aug 03, 2025
by
Seiji Eicher
Committed by
GitHub
Aug 03, 2025
Browse files
Use `aiohttp` connection pool for benchmarking (#21981)
Signed-off-by:
Seiji Eicher
<
seiji@anyscale.com
>
parent
6a39ba85
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
276 additions
and
247 deletions
+276
-247
vllm/benchmarks/lib/endpoint_request_func.py
vllm/benchmarks/lib/endpoint_request_func.py
+238
-241
vllm/benchmarks/lib/ready_checker.py
vllm/benchmarks/lib/ready_checker.py
+3
-1
vllm/benchmarks/serve.py
vllm/benchmarks/serve.py
+35
-5
No files found.
vllm/benchmarks/lib/endpoint_request_func.py
View file @
6f547829
...
@@ -50,6 +50,7 @@ class RequestFuncOutput:
...
@@ -50,6 +50,7 @@ class RequestFuncOutput:
async
def
async_request_openai_completions
(
async
def
async_request_openai_completions
(
request_func_input
:
RequestFuncInput
,
request_func_input
:
RequestFuncInput
,
session
:
aiohttp
.
ClientSession
,
pbar
:
Optional
[
tqdm
]
=
None
,
pbar
:
Optional
[
tqdm
]
=
None
,
)
->
RequestFuncOutput
:
)
->
RequestFuncOutput
:
"""The async request function for the OpenAI Completions API.
"""The async request function for the OpenAI Completions API.
...
@@ -66,96 +67,94 @@ async def async_request_openai_completions(
...
@@ -66,96 +67,94 @@ async def async_request_openai_completions(
(
"completions"
,
"profile"
)
(
"completions"
,
"profile"
)
),
"OpenAI Completions API URL must end with 'completions' or 'profile'."
),
"OpenAI Completions API URL must end with 'completions' or 'profile'."
async
with
aiohttp
.
ClientSession
(
trust_env
=
True
,
payload
=
{
timeout
=
AIOHTTP_TIMEOUT
)
as
session
:
"model"
:
request_func_input
.
model_name
\
payload
=
{
if
request_func_input
.
model_name
else
request_func_input
.
model
,
"model"
:
request_func_input
.
model_name
\
"prompt"
:
request_func_input
.
prompt
,
if
request_func_input
.
model_name
else
request_func_input
.
model
,
"temperature"
:
0.0
,
"prompt"
:
request_func_input
.
prompt
,
"repetition_penalty"
:
1.0
,
"temperature"
:
0.0
,
"max_tokens"
:
request_func_input
.
output_len
,
"repetition_penalty"
:
1.0
,
"logprobs"
:
request_func_input
.
logprobs
,
"max_tokens"
:
request_func_input
.
output_len
,
"stream"
:
True
,
"logprobs"
:
request_func_input
.
logprobs
,
"stream_options"
:
{
"stream"
:
True
,
"include_usage"
:
True
,
"stream_options"
:
{
},
"include_usage"
:
True
,
}
},
if
request_func_input
.
ignore_eos
:
}
payload
[
"ignore_eos"
]
=
request_func_input
.
ignore_eos
if
request_func_input
.
ignore_eos
:
if
request_func_input
.
extra_body
:
payload
[
"ignore_eos"
]
=
request_func_input
.
ignore_eos
payload
.
update
(
request_func_input
.
extra_body
)
if
request_func_input
.
extra_body
:
headers
=
{
payload
.
update
(
request_func_input
.
extra_body
)
"Authorization"
:
f
"Bearer
{
os
.
environ
.
get
(
'OPENAI_API_KEY'
)
}
"
headers
=
{
}
"Authorization"
:
f
"Bearer
{
os
.
environ
.
get
(
'OPENAI_API_KEY'
)
}
"
}
output
=
RequestFuncOutput
()
output
.
prompt_len
=
request_func_input
.
prompt_len
output
=
RequestFuncOutput
()
output
.
prompt_len
=
request_func_input
.
prompt_len
generated_text
=
""
st
=
time
.
perf_counter
()
generated_text
=
""
most_recent_timestamp
=
st
st
=
time
.
perf_counter
()
try
:
most_recent_timestamp
=
st
async
with
session
.
post
(
url
=
api_url
,
json
=
payload
,
try
:
headers
=
headers
)
as
response
:
async
with
session
.
post
(
url
=
api_url
,
json
=
payload
,
if
response
.
status
==
200
:
headers
=
headers
)
as
response
:
first_chunk_received
=
False
if
response
.
status
==
200
:
async
for
chunk_bytes
in
response
.
content
:
first_chunk_received
=
False
chunk_bytes
=
chunk_bytes
.
strip
()
async
for
chunk_bytes
in
response
.
content
:
if
not
chunk_bytes
:
chunk_bytes
=
chunk_bytes
.
strip
()
continue
if
not
chunk_bytes
:
chunk_bytes
=
chunk_bytes
.
decode
(
"utf-8"
)
continue
# NOTE: SSE comments (often used as pings) start with
chunk_bytes
=
chunk_bytes
.
decode
(
"utf-8"
)
# a colon. These are not JSON data payload and should
# NOTE: SSE comments (often used as pings) start with
# be skipped.
# a colon. These are not JSON data payload and should
if
chunk_bytes
.
startswith
(
":"
):
# be skipped.
continue
if
chunk_bytes
.
startswith
(
":"
):
continue
chunk
=
chunk_bytes
.
removeprefix
(
"data: "
)
chunk
=
chunk_bytes
.
removeprefix
(
"data: "
)
if
chunk
!=
"[DONE]"
:
data
=
json
.
loads
(
chunk
)
if
chunk
!=
"[DONE]"
:
data
=
json
.
loads
(
chunk
)
# NOTE: Some completion API might have a last
# usage summary response without a token so we
# NOTE: Some completion API might have a last
# want to check a token was generated
# usage summary response without a token so we
if
choices
:
=
data
.
get
(
"choices"
):
# want to check a token was generated
# Note that text could be empty here
if
choices
:
=
data
.
get
(
"choices"
):
# e.g. for special tokens
# Note that text could be empty here
text
=
choices
[
0
].
get
(
"text"
)
# e.g. for special tokens
timestamp
=
time
.
perf_counter
()
text
=
choices
[
0
].
get
(
"text"
)
# First token
timestamp
=
time
.
perf_counter
()
if
not
first_chunk_received
:
# First token
first_chunk_received
=
True
if
not
first_chunk_received
:
ttft
=
time
.
perf_counter
()
-
st
first_chunk_received
=
True
output
.
ttft
=
ttft
ttft
=
time
.
perf_counter
()
-
st
output
.
ttft
=
ttft
# Decoding phase
# Decoding phase
else
:
else
:
output
.
itl
.
append
(
timestamp
-
output
.
itl
.
append
(
timestamp
-
most_recent_timestamp
)
most_recent_timestamp
)
most_recent_timestamp
=
timestamp
most_recent_timestamp
=
timestamp
generated_text
+=
text
or
""
generated_text
+=
text
or
""
elif
usage
:
=
data
.
get
(
"usage"
):
elif
usage
:
=
data
.
get
(
"usage"
):
output
.
output_tokens
=
usage
.
get
(
output
.
output_tokens
=
usage
.
get
(
"completion_tokens"
)
"completion_tokens"
)
if
first_chunk_received
:
if
first_chunk_received
:
output
.
success
=
True
output
.
success
=
True
else
:
output
.
success
=
False
output
.
error
=
(
"Never received a valid chunk to calculate TTFT."
"This response will be marked as failed!"
)
output
.
generated_text
=
generated_text
output
.
latency
=
most_recent_timestamp
-
st
else
:
else
:
output
.
error
=
response
.
reason
or
""
output
.
success
=
False
output
.
success
=
False
except
Exception
:
output
.
error
=
(
output
.
success
=
False
"Never received a valid chunk to calculate TTFT."
exc_info
=
sys
.
exc_info
()
"This response will be marked as failed!"
)
output
.
error
=
""
.
join
(
traceback
.
format_exception
(
*
exc_info
))
output
.
generated_text
=
generated_text
output
.
latency
=
most_recent_timestamp
-
st
else
:
output
.
error
=
response
.
reason
or
""
output
.
success
=
False
except
Exception
:
output
.
success
=
False
exc_info
=
sys
.
exc_info
()
output
.
error
=
""
.
join
(
traceback
.
format_exception
(
*
exc_info
))
if
pbar
:
if
pbar
:
pbar
.
update
(
1
)
pbar
.
update
(
1
)
...
@@ -164,45 +163,158 @@ async def async_request_openai_completions(
...
@@ -164,45 +163,158 @@ async def async_request_openai_completions(
async
def
async_request_openai_chat_completions
(
async
def
async_request_openai_chat_completions
(
request_func_input
:
RequestFuncInput
,
request_func_input
:
RequestFuncInput
,
session
:
aiohttp
.
ClientSession
,
pbar
:
Optional
[
tqdm
]
=
None
,
pbar
:
Optional
[
tqdm
]
=
None
,
)
->
RequestFuncOutput
:
)
->
RequestFuncOutput
:
api_url
=
request_func_input
.
api_url
api_url
=
request_func_input
.
api_url
assert
api_url
.
endswith
((
"chat/completions"
,
"profile"
)),
(
assert
api_url
.
endswith
((
"chat/completions"
,
"profile"
)),
(
"OpenAI Chat Completions API URL must end with 'chat/completions'."
)
"OpenAI Chat Completions API URL must end with 'chat/completions'."
)
async
with
aiohttp
.
ClientSession
(
trust_env
=
True
,
content
=
[{
"type"
:
"text"
,
"text"
:
request_func_input
.
prompt
}]
timeout
=
AIOHTTP_TIMEOUT
)
as
session
:
if
request_func_input
.
multi_modal_content
:
content
=
[{
"type"
:
"text"
,
"text"
:
request_func_input
.
prompt
}]
content
.
append
(
request_func_input
.
multi_modal_content
)
if
request_func_input
.
multi_modal_content
:
payload
=
{
content
.
append
(
request_func_input
.
multi_modal_content
)
"model"
:
payload
=
{
request_func_input
.
model_name
"model"
:
if
request_func_input
.
model_name
else
request_func_input
.
model
,
request_func_input
.
model_name
"messages"
:
[
if
request_func_input
.
model_name
else
request_func_input
.
model
,
{
"messages"
:
[
"role"
:
"user"
,
{
"content"
:
content
"role"
:
"user"
,
"content"
:
content
},
],
"temperature"
:
0.0
,
"max_completion_tokens"
:
request_func_input
.
output_len
,
"stream"
:
True
,
"stream_options"
:
{
"include_usage"
:
True
,
},
},
}
],
if
request_func_input
.
ignore_eos
:
"temperature"
:
payload
[
"ignore_eos"
]
=
request_func_input
.
ignore_eos
0.0
,
if
request_func_input
.
extra_body
:
"max_completion_tokens"
:
payload
.
update
(
request_func_input
.
extra_body
)
request_func_input
.
output_len
,
headers
=
{
"stream"
:
"Content-Type"
:
"application/json"
,
True
,
"Authorization"
:
f
"Bearer
{
os
.
environ
.
get
(
'OPENAI_API_KEY'
)
}
"
,
"stream_options"
:
{
}
"include_usage"
:
True
,
},
}
if
request_func_input
.
ignore_eos
:
payload
[
"ignore_eos"
]
=
request_func_input
.
ignore_eos
if
request_func_input
.
extra_body
:
payload
.
update
(
request_func_input
.
extra_body
)
headers
=
{
"Content-Type"
:
"application/json"
,
"Authorization"
:
f
"Bearer
{
os
.
environ
.
get
(
'OPENAI_API_KEY'
)
}
"
,
}
output
=
RequestFuncOutput
()
output
.
prompt_len
=
request_func_input
.
prompt_len
generated_text
=
""
ttft
=
0.0
st
=
time
.
perf_counter
()
most_recent_timestamp
=
st
try
:
async
with
session
.
post
(
url
=
api_url
,
json
=
payload
,
headers
=
headers
)
as
response
:
if
response
.
status
==
200
:
async
for
chunk_bytes
in
response
.
content
:
chunk_bytes
=
chunk_bytes
.
strip
()
if
not
chunk_bytes
:
continue
chunk_bytes
=
chunk_bytes
.
decode
(
"utf-8"
)
# NOTE: SSE comments (often used as pings) start with
# a colon. These are not JSON data payload and should
# be skipped.
if
chunk_bytes
.
startswith
(
":"
):
continue
chunk
=
chunk_bytes
.
removeprefix
(
"data: "
)
if
chunk
!=
"[DONE]"
:
timestamp
=
time
.
perf_counter
()
data
=
json
.
loads
(
chunk
)
if
choices
:
=
data
.
get
(
"choices"
):
content
=
choices
[
0
][
"delta"
].
get
(
"content"
)
# First token
if
ttft
==
0.0
:
ttft
=
timestamp
-
st
output
.
ttft
=
ttft
# Decoding phase
else
:
output
.
itl
.
append
(
timestamp
-
most_recent_timestamp
)
generated_text
+=
content
or
""
elif
usage
:
=
data
.
get
(
"usage"
):
output
.
output_tokens
=
usage
.
get
(
"completion_tokens"
)
most_recent_timestamp
=
timestamp
output
.
generated_text
=
generated_text
output
.
success
=
True
output
.
latency
=
most_recent_timestamp
-
st
else
:
output
.
error
=
response
.
reason
or
""
output
.
success
=
False
except
Exception
:
output
.
success
=
False
exc_info
=
sys
.
exc_info
()
output
.
error
=
""
.
join
(
traceback
.
format_exception
(
*
exc_info
))
if
pbar
:
pbar
.
update
(
1
)
return
output
async
def
async_request_openai_audio
(
request_func_input
:
RequestFuncInput
,
session
:
aiohttp
.
ClientSession
,
pbar
:
Optional
[
tqdm
]
=
None
,
)
->
RequestFuncOutput
:
# Lazy import without PlaceholderModule to avoid vllm dep.
import
soundfile
api_url
=
request_func_input
.
api_url
assert
api_url
.
endswith
((
"transcriptions"
,
"translations"
)),
(
"OpenAI Chat Completions API URL must end with 'transcriptions' "
)
"or `translations`."
content
=
[{
"type"
:
"text"
,
"text"
:
request_func_input
.
prompt
}]
payload
=
{
"model"
:
request_func_input
.
model_name
if
request_func_input
.
model_name
else
request_func_input
.
model
,
"temperature"
:
0.0
,
"max_completion_tokens"
:
request_func_input
.
output_len
,
"stream"
:
True
,
"language"
:
"en"
,
# Flattened due to multipart/form-data
"stream_include_usage"
:
True
,
"stream_continuous_usage_stats"
:
True
,
}
if
request_func_input
.
extra_body
:
payload
.
update
(
request_func_input
.
extra_body
)
headers
=
{
"Authorization"
:
f
"Bearer
{
os
.
environ
.
get
(
'OPENAI_API_KEY'
)
}
"
,
}
# Send audio file
def
to_bytes
(
y
,
sr
):
buffer
=
io
.
BytesIO
()
soundfile
.
write
(
buffer
,
y
,
sr
,
format
=
"WAV"
)
buffer
.
seek
(
0
)
return
buffer
with
to_bytes
(
*
request_func_input
.
multi_modal_content
[
"audio"
])
as
f
:
form
=
aiohttp
.
FormData
()
form
.
add_field
(
"file"
,
f
,
content_type
=
"audio/wav"
)
for
key
,
value
in
payload
.
items
():
form
.
add_field
(
key
,
str
(
value
))
output
=
RequestFuncOutput
()
output
=
RequestFuncOutput
()
output
.
prompt_len
=
request_func_input
.
prompt_len
output
.
prompt_len
=
request_func_input
.
prompt_len
...
@@ -212,28 +324,24 @@ async def async_request_openai_chat_completions(
...
@@ -212,28 +324,24 @@ async def async_request_openai_chat_completions(
st
=
time
.
perf_counter
()
st
=
time
.
perf_counter
()
most_recent_timestamp
=
st
most_recent_timestamp
=
st
try
:
try
:
async
with
session
.
post
(
url
=
api_url
,
json
=
payload
,
async
with
session
.
post
(
url
=
api_url
,
data
=
form
,
headers
=
headers
)
as
response
:
headers
=
headers
)
as
response
:
if
response
.
status
==
200
:
if
response
.
status
==
200
:
async
for
chunk_bytes
in
response
.
content
:
async
for
chunk_bytes
in
response
.
content
:
chunk_bytes
=
chunk_bytes
.
strip
()
chunk_bytes
=
chunk_bytes
.
strip
()
if
not
chunk_bytes
:
if
not
chunk_bytes
:
continue
continue
chunk_bytes
=
chunk_bytes
.
decode
(
"utf-8"
)
# NOTE: SSE comments (often used as pings) start with
# a colon. These are not JSON data payload and should
# be skipped.
if
chunk_bytes
.
startswith
(
":"
):
continue
chunk
=
chunk_bytes
.
removeprefix
(
"data: "
)
chunk
=
chunk_bytes
.
decode
(
"utf-8"
).
removeprefix
(
"data: "
)
if
chunk
!=
"[DONE]"
:
if
chunk
!=
"[DONE]"
:
timestamp
=
time
.
perf_counter
()
timestamp
=
time
.
perf_counter
()
data
=
json
.
loads
(
chunk
)
data
=
json
.
loads
(
chunk
)
if
choices
:
=
data
.
get
(
"choices"
):
if
choices
:
=
data
.
get
(
"choices"
):
content
=
choices
[
0
][
"delta"
].
get
(
"content"
)
content
=
choices
[
0
][
"delta"
].
get
(
"content"
)
# First token
# First token
if
ttft
==
0.0
:
if
ttft
==
0.0
:
ttft
=
timestamp
-
st
ttft
=
timestamp
-
st
...
@@ -241,8 +349,8 @@ async def async_request_openai_chat_completions(
...
@@ -241,8 +349,8 @@ async def async_request_openai_chat_completions(
# Decoding phase
# Decoding phase
else
:
else
:
output
.
itl
.
append
(
timestamp
-
output
.
itl
.
append
(
most_recent_timestamp
)
timestamp
-
most_recent_timestamp
)
generated_text
+=
content
or
""
generated_text
+=
content
or
""
elif
usage
:
=
data
.
get
(
"usage"
):
elif
usage
:
=
data
.
get
(
"usage"
):
...
@@ -267,117 +375,6 @@ async def async_request_openai_chat_completions(
...
@@ -267,117 +375,6 @@ async def async_request_openai_chat_completions(
return
output
return
output
async
def
async_request_openai_audio
(
request_func_input
:
RequestFuncInput
,
pbar
:
Optional
[
tqdm
]
=
None
,
)
->
RequestFuncOutput
:
# Lazy import without PlaceholderModule to avoid vllm dep.
import
soundfile
api_url
=
request_func_input
.
api_url
assert
api_url
.
endswith
((
"transcriptions"
,
"translations"
)),
(
"OpenAI Chat Completions API URL must end with 'transcriptions' "
)
"or `translations`."
async
with
aiohttp
.
ClientSession
(
trust_env
=
True
,
timeout
=
AIOHTTP_TIMEOUT
)
as
session
:
content
=
[{
"type"
:
"text"
,
"text"
:
request_func_input
.
prompt
}]
payload
=
{
"model"
:
request_func_input
.
model_name
if
request_func_input
.
model_name
else
request_func_input
.
model
,
"temperature"
:
0.0
,
"max_completion_tokens"
:
request_func_input
.
output_len
,
"stream"
:
True
,
"language"
:
"en"
,
# Flattened due to multipart/form-data
"stream_include_usage"
:
True
,
"stream_continuous_usage_stats"
:
True
,
}
if
request_func_input
.
extra_body
:
payload
.
update
(
request_func_input
.
extra_body
)
headers
=
{
"Authorization"
:
f
"Bearer
{
os
.
environ
.
get
(
'OPENAI_API_KEY'
)
}
"
,
}
# Send audio file
def
to_bytes
(
y
,
sr
):
buffer
=
io
.
BytesIO
()
soundfile
.
write
(
buffer
,
y
,
sr
,
format
=
"WAV"
)
buffer
.
seek
(
0
)
return
buffer
with
to_bytes
(
*
request_func_input
.
multi_modal_content
[
"audio"
])
as
f
:
form
=
aiohttp
.
FormData
()
form
.
add_field
(
"file"
,
f
,
content_type
=
"audio/wav"
)
for
key
,
value
in
payload
.
items
():
form
.
add_field
(
key
,
str
(
value
))
output
=
RequestFuncOutput
()
output
.
prompt_len
=
request_func_input
.
prompt_len
generated_text
=
""
ttft
=
0.0
st
=
time
.
perf_counter
()
most_recent_timestamp
=
st
try
:
async
with
session
.
post
(
url
=
api_url
,
data
=
form
,
headers
=
headers
)
as
response
:
if
response
.
status
==
200
:
async
for
chunk_bytes
in
response
.
content
:
chunk_bytes
=
chunk_bytes
.
strip
()
if
not
chunk_bytes
:
continue
chunk
=
chunk_bytes
.
decode
(
"utf-8"
).
removeprefix
(
"data: "
)
if
chunk
!=
"[DONE]"
:
timestamp
=
time
.
perf_counter
()
data
=
json
.
loads
(
chunk
)
if
choices
:
=
data
.
get
(
"choices"
):
content
=
choices
[
0
][
"delta"
].
get
(
"content"
)
# First token
if
ttft
==
0.0
:
ttft
=
timestamp
-
st
output
.
ttft
=
ttft
# Decoding phase
else
:
output
.
itl
.
append
(
timestamp
-
most_recent_timestamp
)
generated_text
+=
content
or
""
elif
usage
:
=
data
.
get
(
"usage"
):
output
.
output_tokens
=
usage
.
get
(
"completion_tokens"
)
most_recent_timestamp
=
timestamp
output
.
generated_text
=
generated_text
output
.
success
=
True
output
.
latency
=
most_recent_timestamp
-
st
else
:
output
.
error
=
response
.
reason
or
""
output
.
success
=
False
except
Exception
:
output
.
success
=
False
exc_info
=
sys
.
exc_info
()
output
.
error
=
""
.
join
(
traceback
.
format_exception
(
*
exc_info
))
if
pbar
:
pbar
.
update
(
1
)
return
output
# TODO: Add more request functions for different API protocols.
# TODO: Add more request functions for different API protocols.
ASYNC_REQUEST_FUNCS
=
{
ASYNC_REQUEST_FUNCS
=
{
"vllm"
:
async_request_openai_completions
,
"vllm"
:
async_request_openai_completions
,
...
...
vllm/benchmarks/lib/ready_checker.py
View file @
6f547829
...
@@ -14,6 +14,7 @@ from .endpoint_request_func import RequestFuncInput, RequestFuncOutput
...
@@ -14,6 +14,7 @@ from .endpoint_request_func import RequestFuncInput, RequestFuncOutput
async
def
wait_for_endpoint
(
async
def
wait_for_endpoint
(
request_func
,
request_func
,
test_input
:
RequestFuncInput
,
test_input
:
RequestFuncInput
,
session
:
aiohttp
.
ClientSession
,
timeout_seconds
:
int
=
600
,
timeout_seconds
:
int
=
600
,
retry_interval
:
int
=
5
,
retry_interval
:
int
=
5
,
)
->
RequestFuncOutput
:
)
->
RequestFuncOutput
:
...
@@ -55,7 +56,8 @@ async def wait_for_endpoint(
...
@@ -55,7 +56,8 @@ async def wait_for_endpoint(
# ping the endpoint using request_func
# ping the endpoint using request_func
try
:
try
:
output
=
await
request_func
(
request_func_input
=
test_input
)
output
=
await
request_func
(
request_func_input
=
test_input
,
session
=
session
)
if
output
.
success
:
if
output
.
success
:
pbar
.
close
()
pbar
.
close
()
return
output
return
output
...
...
vllm/benchmarks/serve.py
View file @
6f547829
...
@@ -28,6 +28,7 @@ from dataclasses import dataclass
...
@@ -28,6 +28,7 @@ from dataclasses import dataclass
from
datetime
import
datetime
from
datetime
import
datetime
from
typing
import
Any
,
Literal
,
Optional
from
typing
import
Any
,
Literal
,
Optional
import
aiohttp
import
numpy
as
np
import
numpy
as
np
from
tqdm.asyncio
import
tqdm
from
tqdm.asyncio
import
tqdm
from
transformers
import
PreTrainedTokenizerBase
from
transformers
import
PreTrainedTokenizerBase
...
@@ -338,6 +339,24 @@ async def benchmark(
...
@@ -338,6 +339,24 @@ async def benchmark(
else
:
else
:
raise
ValueError
(
f
"Unknown endpoint_type:
{
endpoint_type
}
"
)
raise
ValueError
(
f
"Unknown endpoint_type:
{
endpoint_type
}
"
)
# Reuses connections across requests to reduce TLS handshake overhead.
connector
=
aiohttp
.
TCPConnector
(
limit
=
max_concurrency
or
0
,
limit_per_host
=
max_concurrency
or
0
,
ttl_dns_cache
=
300
,
use_dns_cache
=
True
,
keepalive_timeout
=
60
,
enable_cleanup_closed
=
True
,
force_close
=
False
,
ssl
=
(
"https://"
in
api_url
),
)
session
=
aiohttp
.
ClientSession
(
connector
=
connector
,
trust_env
=
True
,
timeout
=
aiohttp
.
ClientTimeout
(
total
=
6
*
60
*
60
),
)
print
(
"Starting initial single prompt test run..."
)
print
(
"Starting initial single prompt test run..."
)
test_prompt
,
test_prompt_len
,
test_output_len
,
test_mm_content
=
(
test_prompt
,
test_prompt_len
,
test_output_len
,
test_mm_content
=
(
input_requests
[
0
].
prompt
,
input_requests
[
0
].
prompt
,
...
@@ -361,7 +380,11 @@ async def benchmark(
...
@@ -361,7 +380,11 @@ async def benchmark(
)
)
test_output
=
await
wait_for_endpoint
(
test_output
=
await
wait_for_endpoint
(
request_func
,
test_input
,
timeout_seconds
=
ready_check_timeout_sec
)
request_func
,
test_input
,
session
,
timeout_seconds
=
ready_check_timeout_sec
,
)
if
not
test_output
.
success
:
if
not
test_output
.
success
:
raise
ValueError
(
raise
ValueError
(
"Initial test run failed - Please make sure benchmark arguments "
"Initial test run failed - Please make sure benchmark arguments "
...
@@ -386,7 +409,8 @@ async def benchmark(
...
@@ -386,7 +409,8 @@ async def benchmark(
multi_modal_content
=
test_mm_content
,
multi_modal_content
=
test_mm_content
,
ignore_eos
=
ignore_eos
,
ignore_eos
=
ignore_eos
,
extra_body
=
extra_body
)
extra_body
=
extra_body
)
profile_output
=
await
request_func
(
request_func_input
=
profile_input
)
profile_output
=
await
request_func
(
request_func_input
=
profile_input
,
session
=
session
)
if
profile_output
.
success
:
if
profile_output
.
success
:
print
(
"Profiler started"
)
print
(
"Profiler started"
)
...
@@ -412,12 +436,14 @@ async def benchmark(
...
@@ -412,12 +436,14 @@ async def benchmark(
semaphore
=
(
asyncio
.
Semaphore
(
max_concurrency
)
semaphore
=
(
asyncio
.
Semaphore
(
max_concurrency
)
if
max_concurrency
else
None
)
if
max_concurrency
else
None
)
async
def
limited_request_func
(
request_func_input
,
pbar
):
async
def
limited_request_func
(
request_func_input
,
session
,
pbar
):
if
semaphore
is
None
:
if
semaphore
is
None
:
return
await
request_func
(
request_func_input
=
request_func_input
,
return
await
request_func
(
request_func_input
=
request_func_input
,
session
=
session
,
pbar
=
pbar
)
pbar
=
pbar
)
async
with
semaphore
:
async
with
semaphore
:
return
await
request_func
(
request_func_input
=
request_func_input
,
return
await
request_func
(
request_func_input
=
request_func_input
,
session
=
session
,
pbar
=
pbar
)
pbar
=
pbar
)
benchmark_start_time
=
time
.
perf_counter
()
benchmark_start_time
=
time
.
perf_counter
()
...
@@ -469,6 +495,7 @@ async def benchmark(
...
@@ -469,6 +495,7 @@ async def benchmark(
tasks
.
append
(
tasks
.
append
(
asyncio
.
create_task
(
asyncio
.
create_task
(
limited_request_func
(
request_func_input
=
request_func_input
,
limited_request_func
(
request_func_input
=
request_func_input
,
session
=
session
,
pbar
=
pbar
)))
pbar
=
pbar
)))
outputs
:
list
[
RequestFuncOutput
]
=
await
asyncio
.
gather
(
*
tasks
)
outputs
:
list
[
RequestFuncOutput
]
=
await
asyncio
.
gather
(
*
tasks
)
...
@@ -580,9 +607,12 @@ async def benchmark(
...
@@ -580,9 +607,12 @@ async def benchmark(
output_len
=
test_output_len
,
output_len
=
test_output_len
,
logprobs
=
logprobs
,
logprobs
=
logprobs
,
)
)
profile_output
=
await
request_func
(
request_func_input
=
profile_input
)
profile_output
=
await
request_func
(
request_func_input
=
profile_input
,
session
=
session
)
if
profile_output
.
success
:
if
profile_output
.
success
:
print
(
"Profiler stopped"
)
print
(
"Profiler stopped"
)
await
session
.
close
()
return
result
return
result
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment