Unverified Commit 4efe2c57 authored by Lzhang-hub, committed by GitHub

support vlm model spec bench (#10173)

parent 5be8c2f7
@@ -16,8 +16,14 @@ from types import SimpleNamespace
 import numpy as np
 import requests
+from transformers import AutoTokenizer
 
-from sglang.bench_serving import DatasetRow, benchmark, set_global_args
+from sglang.bench_serving import (
+    DatasetRow,
+    benchmark,
+    sample_mmmu_requests,
+    set_global_args,
+)
 from sglang.srt.server_args import ServerArgs
 from sglang.test.test_utils import (
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -48,20 +54,33 @@ class FakeTokenizer:
         return []
 
 
-def send_one_batch(base_url, num_prompts, batch_size):
-    padded_prompts = (prompts * ((num_prompts + len(prompts) - 1) // len(prompts)))[
-        :num_prompts
-    ]
-
+def send_one_batch(base_url, num_prompts, batch_size, tokenizer, is_multimodal):
     # format: (prompt, input_len, output len). We set input_len as a dummy value 0.
-    input_requests: List[DatasetRow] = [DatasetRow(p, 0, 512) for p in padded_prompts]
+    if is_multimodal:
+        input_requests = sample_mmmu_requests(
+            num_prompts,
+            tokenizer,
+            512,
+            apply_chat_template=False,
+        )
+        backend = "sglang-oai-chat"
+        api_url = f"{base_url}/v1/chat/completions"
+    else:
+        padded_prompts = (prompts * ((num_prompts + len(prompts) - 1) // len(prompts)))[
+            :num_prompts
+        ]
+        input_requests: List[DatasetRow] = [
+            DatasetRow(p, 0, 512) for p in padded_prompts
+        ]
+        backend = "sglang"
+        api_url = f"{base_url}/generate"
 
     # We need to set some dummy values in order to call `benchmark` below.
     args = SimpleNamespace(
         disable_ignore_eos=False,
         disable_stream=False,
         return_logprob=False,
-        backend="sglang",
+        backend=backend,
         dataset_name="custom",
         num_prompts=None,
         sharegpt_output_len=None,
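Note on the text-only branch above: it pads a small fixed prompt list up to `num_prompts` by ceiling-division repetition. A minimal sketch of that arithmetic, with illustrative prompt strings standing in for the script's module-level `prompts` list:

```python
# Sketch of the round-robin padding in the else branch; these prompt
# strings are illustrative stand-ins for the script's module-level list.
prompts = ["What is the capital of France?", "Explain KV caching briefly."]

def pad_prompts(prompts, num_prompts):
    # Ceiling division: smallest repeat count whose total covers num_prompts.
    reps = (num_prompts + len(prompts) - 1) // len(prompts)
    return (prompts * reps)[:num_prompts]

print(pad_prompts(prompts, 5))  # p0, p1, p0, p1, p0 -> exactly 5 entries
```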
@@ -73,13 +92,12 @@ def send_one_batch(base_url, num_prompts, batch_size):
         output_details=False,
     )
     set_global_args(args)
-    tokenizer = FakeTokenizer()
 
     # Run benchmark
     results = asyncio.run(
         benchmark(
-            backend="sglang",
-            api_url=f"{base_url}/generate",
+            backend=backend,
+            api_url=api_url,
             base_url=base_url,
             model_id="default",
             tokenizer=tokenizer,
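The backend name and API URL now travel together through `send_one_batch`: multimodal runs go to the OpenAI-compatible chat endpoint, text-only runs to the native generate endpoint. A small illustrative helper (not part of this diff) capturing that pairing, with a placeholder base URL:

```python
# Illustrative helper (not in the diff) showing the backend/endpoint
# pairing that send_one_batch selects; the base_url is a placeholder.
def pick_endpoint(base_url, is_multimodal):
    if is_multimodal:
        return "sglang-oai-chat", f"{base_url}/v1/chat/completions"
    return "sglang", f"{base_url}/generate"

backend, api_url = pick_endpoint("http://127.0.0.1:30000", is_multimodal=True)
print(backend, api_url)  # sglang-oai-chat http://127.0.0.1:30000/v1/chat/completions
```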
@@ -143,8 +161,6 @@ def main(args, server_args):
         other_args = []
     else:
         other_args = [
-            "--speculative-algorithm",
-            "EAGLE",
             "--speculative-num-steps",
             steps,
             "--speculative-eagle-topk",
@@ -157,6 +173,8 @@
         [
             "--speculative-draft-model-path",
             server_args.speculative_draft_model_path,
+            "--speculative-algorithm",
+            server_args.speculative_algorithm,
         ]
     )
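With these two hunks, `--speculative-algorithm` is no longer hardcoded to EAGLE but forwarded from `server_args`, so the same sweep can benchmark other draft-model algorithms. A sketch of how the flag list ends up assembled, with illustrative literal values in place of the script's loop variables:

```python
# Illustrative values; in the script, steps/topk come from the sweep loop
# and the algorithm/draft path from server_args.
steps, topk = 3, 4
other_args = [
    "--speculative-num-steps", str(steps),
    "--speculative-eagle-topk", str(topk),
]
other_args.extend(
    [
        "--speculative-draft-model-path", "/path/to/draft-model",  # placeholder
        "--speculative-algorithm", "EAGLE",  # e.g. server_args.speculative_algorithm
    ]
)
```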
@@ -207,13 +225,23 @@ def main(args, server_args):
         },
     )
 
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.model_path, trust_remote_code=server_args.trust_remote_code
+    )
+
     try:
         # Warmup
-        send_one_batch(base_url, batch_size, batch_size)
+        send_one_batch(
+            base_url, batch_size, batch_size, tokenizer, args.is_multimodal
+        )
 
         # Benchmark
         acc_length, step_time, speed, completion_tokens = send_one_batch(
-            base_url, max(args.num_prompts, batch_size), batch_size
+            base_url,
+            max(args.num_prompts, batch_size),
+            batch_size,
+            tokenizer,
+            args.is_multimodal,
         )
     finally:
         kill_process_tree(process.pid)
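A real Hugging Face tokenizer replaces the old `FakeTokenizer` here, since `sample_mmmu_requests` needs it to build and count multimodal prompts. The warmup-then-measure pattern above, restated as a sketch with placeholder values (assumes a server is already listening at `base_url` and `send_one_batch` is in scope):

```python
# Sketch only: assumes a running server at base_url and the
# send_one_batch signature introduced in this diff.
from transformers import AutoTokenizer

base_url = "http://127.0.0.1:30000"  # placeholder
batch_size, num_prompts = 16, 64     # placeholders
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")  # placeholder model

send_one_batch(base_url, batch_size, batch_size, tokenizer, True)  # warmup; result discarded
acc_length, step_time, speed, completion_tokens = send_one_batch(
    base_url,
    max(num_prompts, batch_size),  # never measure with less than one full batch
    batch_size,
    tokenizer,
    True,  # is_multimodal
)
```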
@@ -273,6 +301,7 @@ if __name__ == "__main__":
     parser.add_argument("--start", type=int, default=0)
     parser.add_argument("--end", type=int)
     parser.add_argument("--output", type=str, default="output.jsonl")
+    parser.add_argument("--is-multimodal", action="store_true", default=False)
     args = parser.parse_args()
     server_args: ServerArgs = ServerArgs.from_cli_args(args)
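The new `--is-multimodal` flag defaults to False, so existing text-only invocations behave exactly as before; passing the flag switches the benchmark to MMMU image requests. A minimal argparse reproduction of its semantics:

```python
# store_true semantics: flag absent -> False (text path),
# flag present -> True (multimodal MMMU path).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--is-multimodal", action="store_true", default=False)
assert parser.parse_args([]).is_multimodal is False
assert parser.parse_args(["--is-multimodal"]).is_multimodal is True
```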