"tests/python/common/test_heterograph.py" did not exist on "1c91f460d3e534ed549bf600820d7cc31a0981ff"
Unverified Commit 4efe2c57 authored by Lzhang-hub, committed by GitHub

support vlm model spec bench (#10173)

parent 5be8c2f7
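At a glance, the hunks below thread a real tokenizer and an `is_multimodal` flag through `send_one_batch`, so the speculative-decoding benchmark can drive a VLM through the OpenAI-compatible chat endpoint instead of the text-only `/generate` route. A condensed sketch of the new request-selection logic, assembled from the hunks below (an illustration, not the verbatim file; `prompts`, `DatasetRow`, and `sample_mmmu_requests` come from the surrounding script and `sglang.bench_serving`):

    if is_multimodal:
        # New VLM path: MMMU image+text prompts, benchmarked via the
        # OpenAI-compatible chat-completions API.
        input_requests = sample_mmmu_requests(
            num_prompts, tokenizer, 512, apply_chat_template=False
        )
        backend = "sglang-oai-chat"
        api_url = f"{base_url}/v1/chat/completions"
    else:
        # Original text-only path: repeat the fixed prompt list and hit /generate.
        padded_prompts = (prompts * ((num_prompts + len(prompts) - 1) // len(prompts)))[
            :num_prompts
        ]
        input_requests = [DatasetRow(p, 0, 512) for p in padded_prompts]
        backend = "sglang"
        api_url = f"{base_url}/generate"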
@@ -16,8 +16,14 @@ from types import SimpleNamespace
 import numpy as np
 import requests
 from transformers import AutoTokenizer
-from sglang.bench_serving import DatasetRow, benchmark, set_global_args
+from sglang.bench_serving import (
+    DatasetRow,
+    benchmark,
+    sample_mmmu_requests,
+    set_global_args,
+)
 from sglang.srt.server_args import ServerArgs
 from sglang.test.test_utils import (
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -48,20 +54,33 @@ class FakeTokenizer:
         return []


-def send_one_batch(base_url, num_prompts, batch_size):
-    padded_prompts = (prompts * ((num_prompts + len(prompts) - 1) // len(prompts)))[
-        :num_prompts
-    ]
+def send_one_batch(base_url, num_prompts, batch_size, tokenizer, is_multimodal):
     # format: (prompt, input_len, output len). We set input_len as a dummy value 0.
-    input_requests: List[DatasetRow] = [DatasetRow(p, 0, 512) for p in padded_prompts]
+    if is_multimodal:
+        input_requests = sample_mmmu_requests(
+            num_prompts,
+            tokenizer,
+            512,
+            apply_chat_template=False,
+        )
+        backend = "sglang-oai-chat"
+        api_url = f"{base_url}/v1/chat/completions"
+    else:
+        padded_prompts = (prompts * ((num_prompts + len(prompts) - 1) // len(prompts)))[
+            :num_prompts
+        ]
+        input_requests: List[DatasetRow] = [
+            DatasetRow(p, 0, 512) for p in padded_prompts
+        ]
+        backend = "sglang"
+        api_url = f"{base_url}/generate"

     # We need to set some dummy values in order to call `benchmark` below.
     args = SimpleNamespace(
         disable_ignore_eos=False,
         disable_stream=False,
         return_logprob=False,
-        backend="sglang",
+        backend=backend,
         dataset_name="custom",
         num_prompts=None,
         sharegpt_output_len=None,
@@ -73,13 +92,12 @@ def send_one_batch(base_url, num_prompts, batch_size):
         output_details=False,
     )
     set_global_args(args)
-    tokenizer = FakeTokenizer()

     # Run benchmark
     results = asyncio.run(
         benchmark(
-            backend="sglang",
-            api_url=f"{base_url}/generate",
+            backend=backend,
+            api_url=api_url,
             base_url=base_url,
             model_id="default",
             tokenizer=tokenizer,
@@ -143,8 +161,6 @@ def main(args, server_args):
         other_args = []
     else:
         other_args = [
-            "--speculative-algorithm",
-            "EAGLE",
             "--speculative-num-steps",
             steps,
             "--speculative-eagle-topk",
@@ -157,6 +173,8 @@ def main(args, server_args):
             [
                 "--speculative-draft-model-path",
                 server_args.speculative_draft_model_path,
+                "--speculative-algorithm",
+                server_args.speculative_algorithm,
             ]
         )
@@ -207,13 +225,23 @@ def main(args, server_args):
         },
     )

+    tokenizer = AutoTokenizer.from_pretrained(
+        args.model_path, trust_remote_code=server_args.trust_remote_code
+    )
+
     try:
         # Warmup
-        send_one_batch(base_url, batch_size, batch_size)
+        send_one_batch(
+            base_url, batch_size, batch_size, tokenizer, args.is_multimodal
+        )

         # Benchmark
         acc_length, step_time, speed, completion_tokens = send_one_batch(
-            base_url, max(args.num_prompts, batch_size), batch_size
+            base_url,
+            max(args.num_prompts, batch_size),
+            batch_size,
+            tokenizer,
+            args.is_multimodal,
         )
     finally:
         kill_process_tree(process.pid)
@@ -273,6 +301,7 @@ if __name__ == "__main__":
     parser.add_argument("--start", type=int, default=0)
     parser.add_argument("--end", type=int)
    parser.add_argument("--output", type=str, default="output.jsonl")
+    parser.add_argument("--is-multimodal", action="store_true", default=False)
     args = parser.parse_args()
     server_args: ServerArgs = ServerArgs.from_cli_args(args)
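A hedged usage sketch under stated assumptions (the script name and model paths are placeholders, not taken from the commit; the flags come from the diff and from ServerArgs):

    # Text-only speculative benchmark (behaviour unchanged):
    #   python bench_speculative.py --model-path <model> \
    #       --speculative-algorithm EAGLE --speculative-draft-model-path <draft>
    #
    # VLM benchmark via the chat endpoint (new in this commit):
    #   python bench_speculative.py --model-path <vlm-model> \
    #       --speculative-algorithm EAGLE --speculative-draft-model-path <draft> \
    #       --is-multimodal

Passing `--is-multimodal` makes the warmup and benchmark batches use MMMU image prompts against `/v1/chat/completions`; omitting it keeps the original `/generate` path.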