Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9d4ca19d
Unverified
Commit
9d4ca19d
authored
Apr 19, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Apr 19, 2025
Browse files
[Misc] Benchmarks for audio models (#16505)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
2ef0dc53
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
199 additions
and
5 deletions
+199
-5
benchmarks/backend_request_func.py
benchmarks/backend_request_func.py
+107
-0
benchmarks/benchmark_dataset.py
benchmarks/benchmark_dataset.py
+80
-0
benchmarks/benchmark_serving.py
benchmarks/benchmark_serving.py
+11
-5
tests/entrypoints/openai/correctness/test_transcription_api_correctness.py
.../openai/correctness/test_transcription_api_correctness.py
+1
-0
No files found.
benchmarks/backend_request_func.py
View file @
9d4ca19d
# SPDX-License-Identifier: Apache-2.0
import
io
import
json
import
os
import
sys
...
...
@@ -32,6 +33,7 @@ class RequestFuncInput:
extra_body
:
Optional
[
dict
]
=
None
multi_modal_content
:
Optional
[
dict
]
=
None
ignore_eos
:
bool
=
False
language
:
Optional
[
str
]
=
None
@
dataclass
...
...
@@ -436,6 +438,110 @@ async def async_request_openai_chat_completions(
return
output
async
def
async_request_openai_audio
(
request_func_input
:
RequestFuncInput
,
pbar
:
Optional
[
tqdm
]
=
None
,
)
->
RequestFuncOutput
:
# Lazy import without PlaceholderModule to avoid vllm dep.
import
soundfile
api_url
=
request_func_input
.
api_url
assert
api_url
.
endswith
(
(
"transcriptions"
,
"translations"
)),
"OpenAI Chat Completions API URL must end with 'transcriptions' "
"or `translations`."
async
with
aiohttp
.
ClientSession
(
trust_env
=
True
,
timeout
=
AIOHTTP_TIMEOUT
)
as
session
:
content
=
[{
"type"
:
"text"
,
"text"
:
request_func_input
.
prompt
}]
payload
=
{
"model"
:
request_func_input
.
model_name
\
if
request_func_input
.
model_name
else
request_func_input
.
model
,
"temperature"
:
0.0
,
"max_completion_tokens"
:
request_func_input
.
output_len
,
"stream"
:
True
,
"language"
:
"en"
,
# Flattened due to multipart/form-data
"stream_include_usage"
:
True
,
"stream_continuous_usage_stats"
:
True
}
if
request_func_input
.
extra_body
:
payload
.
update
(
request_func_input
.
extra_body
)
headers
=
{
"Authorization"
:
f
"Bearer
{
os
.
environ
.
get
(
'OPENAI_API_KEY'
)
}
"
,
}
# Send audio file
def
to_bytes
(
y
,
sr
):
buffer
=
io
.
BytesIO
()
soundfile
.
write
(
buffer
,
y
,
sr
,
format
=
"WAV"
)
buffer
.
seek
(
0
)
return
buffer
with
to_bytes
(
*
request_func_input
.
multi_modal_content
[
'audio'
])
as
f
:
form
=
aiohttp
.
FormData
()
form
.
add_field
(
'file'
,
f
,
content_type
=
'audio/wav'
)
for
key
,
value
in
payload
.
items
():
form
.
add_field
(
key
,
str
(
value
))
output
=
RequestFuncOutput
()
output
.
prompt_len
=
request_func_input
.
prompt_len
generated_text
=
""
ttft
=
0.0
st
=
time
.
perf_counter
()
most_recent_timestamp
=
st
try
:
async
with
session
.
post
(
url
=
api_url
,
data
=
form
,
headers
=
headers
)
as
response
:
if
response
.
status
==
200
:
async
for
chunk_bytes
in
response
.
content
:
chunk_bytes
=
chunk_bytes
.
strip
()
if
not
chunk_bytes
:
continue
chunk
=
chunk_bytes
.
decode
(
"utf-8"
).
removeprefix
(
"data: "
)
if
chunk
!=
"[DONE]"
:
timestamp
=
time
.
perf_counter
()
data
=
json
.
loads
(
chunk
)
if
choices
:
=
data
.
get
(
"choices"
):
content
=
choices
[
0
][
"delta"
].
get
(
"content"
)
# First token
if
ttft
==
0.0
:
ttft
=
timestamp
-
st
output
.
ttft
=
ttft
# Decoding phase
else
:
output
.
itl
.
append
(
timestamp
-
most_recent_timestamp
)
generated_text
+=
content
or
""
elif
usage
:
=
data
.
get
(
"usage"
):
output
.
output_tokens
=
usage
.
get
(
"completion_tokens"
)
most_recent_timestamp
=
timestamp
output
.
generated_text
=
generated_text
output
.
success
=
True
output
.
latency
=
most_recent_timestamp
-
st
else
:
output
.
error
=
response
.
reason
or
""
output
.
success
=
False
except
Exception
:
output
.
success
=
False
exc_info
=
sys
.
exc_info
()
output
.
error
=
""
.
join
(
traceback
.
format_exception
(
*
exc_info
))
if
pbar
:
pbar
.
update
(
1
)
return
output
def
get_model
(
pretrained_model_name_or_path
:
str
)
->
str
:
if
os
.
getenv
(
'VLLM_USE_MODELSCOPE'
,
'False'
).
lower
()
==
'true'
:
from
modelscope
import
snapshot_download
...
...
@@ -493,6 +599,7 @@ ASYNC_REQUEST_FUNCS = {
"deepspeed-mii"
:
async_request_deepspeed_mii
,
"openai"
:
async_request_openai_completions
,
"openai-chat"
:
async_request_openai_chat_completions
,
"openai-audio"
:
async_request_openai_audio
,
"tensorrt-llm"
:
async_request_trt_llm
,
"scalellm"
:
async_request_openai_completions
,
"sglang"
:
async_request_openai_completions
,
...
...
benchmarks/benchmark_dataset.py
View file @
9d4ca19d
...
...
@@ -64,6 +64,7 @@ class SampleRequest:
class
BenchmarkDataset
(
ABC
):
DEFAULT_SEED
=
0
IS_MULTIMODAL
=
False
def
__init__
(
self
,
...
...
@@ -621,6 +622,7 @@ class ConversationDataset(HuggingFaceDataset):
SUPPORTED_DATASET_PATHS
=
{
'lmms-lab/LLaVA-OneVision-Data'
,
'Aeala/ShareGPT_Vicuna_unfiltered'
}
IS_MULTIMODAL
=
True
def
sample
(
self
,
tokenizer
:
PreTrainedTokenizerBase
,
...
...
@@ -685,6 +687,7 @@ class VisionArenaDataset(HuggingFaceDataset):
"lmarena-ai/vision-arena-bench-v0.1"
:
lambda
x
:
x
[
"turns"
][
0
][
0
][
"content"
]
}
IS_MULTIMODAL
=
True
def
sample
(
self
,
...
...
@@ -815,3 +818,80 @@ class AIMODataset(HuggingFaceDataset):
))
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
)
return
sampled_requests
# -----------------------------------------------------------------------------
# ASR Dataset Implementation
# -----------------------------------------------------------------------------
class
ASRDataset
(
HuggingFaceDataset
):
"""
Dataset class for processing a ASR dataset for transcription.
Tested on the following set:
+----------------+----------------------------------------+--------------------------+-----------------------------+
| Dataset | Domain | Speaking Style | hf-subset |
+----------------+----------------------------------------+--------------------------+-----------------------------+
| TED-LIUM | TED talks | Oratory | release1, release2, release3|
| | | | release3-speaker-adaptation |
| VoxPopuli | European Parliament | Oratory | en, de, it, fr, ... |
| LibriSpeech | Audiobook | Narrated | "LIUM/tedlium" |
| GigaSpeech | Audiobook, podcast, YouTube | Narrated, spontaneous | xs, s, m, l, xl, dev, test |
| SPGISpeech | Financial meetings | Oratory, spontaneous | S, M, L, dev, test |
| AMI | Meetings | Spontaneous | ihm, sdm |
+----------------+----------------------------------------+--------------------------+-----------------------------+
"""
# noqa: E501
SUPPORTED_DATASET_PATHS
=
{
"openslr/librispeech_asr"
,
"facebook/voxpopuli"
,
"LIUM/tedlium"
,
"edinburghcstr/ami"
,
"speechcolab/gigaspeech"
,
"kensho/spgispeech"
}
DEFAULT_OUTPUT_LEN
=
128
IS_MULTIMODAL
=
True
# TODO Whisper-specific. Abstract interface when more models are supported.
TRANSCRIPTION_PREAMBLE
=
"<|startoftranscript|><|en|><|transcribe|>"
\
"<|notimestamps|>"
skip_long_audios
:
bool
=
True
def
sample
(
self
,
tokenizer
:
PreTrainedTokenizerBase
,
num_requests
:
int
,
output_len
:
Optional
[
int
]
=
None
,
**
kwargs
,
)
->
list
:
import
librosa
output_len
=
(
output_len
if
output_len
is
not
None
else
self
.
DEFAULT_OUTPUT_LEN
)
prompt
=
ASRDataset
.
TRANSCRIPTION_PREAMBLE
prompt_len
=
len
(
tokenizer
(
prompt
).
input_ids
)
sampled_requests
=
[]
skipped
=
0
for
item
in
self
.
data
:
if
len
(
sampled_requests
)
>=
num_requests
:
break
audio
=
item
[
"audio"
]
y
,
sr
=
audio
[
"array"
],
audio
[
"sampling_rate"
]
duration_s
=
librosa
.
get_duration
(
y
=
y
,
sr
=
sr
)
# Whisper max supported duration
if
self
.
skip_long_audios
and
duration_s
>
30
:
skipped
+=
1
continue
mm_content
=
{
"audio"
:
(
y
,
sr
)}
sampled_requests
.
append
(
SampleRequest
(
prompt
=
prompt
,
prompt_len
=
prompt_len
,
expected_output_len
=
output_len
,
multi_modal_data
=
mm_content
,
))
if
skipped
:
logger
.
warning
(
"%d samples discarded from dataset due to"
\
" their length being greater than"
\
" what Whisper supports."
,
skipped
)
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
)
return
sampled_requests
benchmarks/benchmark_serving.py
View file @
9d4ca19d
...
...
@@ -50,7 +50,7 @@ try:
except
ImportError
:
from
argparse
import
ArgumentParser
as
FlexibleArgumentParser
from
benchmark_dataset
import
(
AIMODataset
,
BurstGPTDataset
,
from
benchmark_dataset
import
(
AIMODataset
,
ASRDataset
,
BurstGPTDataset
,
ConversationDataset
,
HuggingFaceDataset
,
InstructCoderDataset
,
RandomDataset
,
SampleRequest
,
ShareGPTDataset
,
SonnetDataset
,
...
...
@@ -274,10 +274,6 @@ async def benchmark(
input_requests
[
0
].
expected_output_len
,
\
input_requests
[
0
].
multi_modal_data
if
backend
!=
"openai-chat"
and
test_mm_content
is
not
None
:
# multi-modal benchmark is only available on OpenAI Chat backend.
raise
ValueError
(
"Multi-modal content is only supported on 'openai-chat' backend."
)
assert
test_mm_content
is
None
or
isinstance
(
test_mm_content
,
dict
)
test_input
=
RequestFuncInput
(
model
=
model_id
,
...
...
@@ -604,6 +600,9 @@ def main(args: argparse.Namespace):
elif
args
.
dataset_path
in
AIMODataset
.
SUPPORTED_DATASET_PATHS
:
dataset_class
=
AIMODataset
args
.
hf_split
=
"train"
elif
args
.
dataset_path
in
ASRDataset
.
SUPPORTED_DATASET_PATHS
:
dataset_class
=
ASRDataset
args
.
hf_split
=
"train"
else
:
supported_datasets
=
set
([
dataset_name
for
cls
in
HuggingFaceDataset
.
__subclasses__
()
...
...
@@ -615,6 +614,13 @@ def main(args: argparse.Namespace):
f
" from one of following:
{
supported_datasets
}
. "
"Please consider contributing if you would "
"like to add support for additional dataset formats."
)
if
(
dataset_class
.
IS_MULTIMODAL
and
backend
not
in
\
[
"openai-chat"
,
"openai-audio"
]):
# multi-modal benchmark is only available on OpenAI Chat backend.
raise
ValueError
(
"Multi-modal content is only supported on 'openai-chat' and "
\
"'openai-audio' backend."
)
input_requests
=
dataset_class
(
dataset_path
=
args
.
dataset_path
,
dataset_subset
=
args
.
hf_subset
,
...
...
tests/entrypoints/openai/correctness/test_transcription_api_correctness.py
View file @
9d4ca19d
...
...
@@ -150,6 +150,7 @@ def test_wer_correctness(model_name,
expected_wer
,
n_examples
=-
1
,
max_concurrent_request
=
None
):
# TODO refactor to use `ASRDataset`
with
RemoteOpenAIServer
(
model_name
,
[
'--enforce-eager'
])
as
remote_server
:
dataset
=
load_hf_dataset
(
dataset_repo
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment