Unverified Commit be292b7c authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Bug] Fix pooling model benchmark script (#36300)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
parent 77a73458
......@@ -795,6 +795,17 @@ ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = {
"vllm-rerank": async_request_vllm_rerank,
}
POOLING_BACKENDS = {
"openai-embeddings",
"openai-embeddings-chat",
"openai-embeddings-clip",
"openai-embeddings-vlm2vec",
"infinity-embeddings",
"infinity-embeddings-clip",
"vllm-pooling",
"vllm-rerank",
}
OPENAI_COMPATIBLE_BACKENDS = [
k
for k, v in ASYNC_REQUEST_FUNCS.items()
......
......@@ -45,6 +45,7 @@ from vllm.benchmarks.datasets import SampleRequest, add_dataset_parser, get_samp
from vllm.benchmarks.lib.endpoint_request_func import (
ASYNC_REQUEST_FUNCS,
OPENAI_COMPATIBLE_BACKENDS,
POOLING_BACKENDS,
RequestFuncInput,
RequestFuncOutput,
)
......@@ -1721,11 +1722,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
goodput_config_dict = check_goodput_args(args)
backend = args.backend
task_type = (
TaskType.POOLING
if "embeddings" in backend or "rerank" in backend
else TaskType.GENERATION
)
task_type = TaskType.POOLING if backend in POOLING_BACKENDS else TaskType.GENERATION
# Collect the sampling parameters.
if task_type == TaskType.GENERATION:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment