Unverified Commit 975a5ec6 authored by Kay Yan, committed by GitHub

[fix] update bench_speculative.py for compatibility (#7764)


Signed-off-by: Kay Yan <kay.yan@daocloud.io>
parent 1e3e3add
@@ -17,7 +17,7 @@ from types import SimpleNamespace
 import numpy as np
 import requests
-from sglang.bench_serving import benchmark, set_global_args
+from sglang.bench_serving import DatasetRow, benchmark, set_global_args
 from sglang.srt.server_args import ServerArgs
 from sglang.test.test_utils import (
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -54,7 +54,7 @@ def send_one_batch(base_url, num_prompts, batch_size):
     ]
     # format: (prompt, input_len, output len). We set input_len as a dummy value 0.
-    input_requests = [(p, 0, 512) for p in padded_prompts]
+    input_requests: List[DatasetRow] = [DatasetRow(p, 0, 512) for p in padded_prompts]
     # We need to set some dummy values in order to call `benchmark` below.
     args = SimpleNamespace(
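For context on the hunk above, a minimal standalone sketch of the request construction before and after the change, assuming DatasetRow takes (prompt, prompt_len, output_len) positionally as the call in the diff suggests:

from typing import List

from sglang.bench_serving import DatasetRow

# Hypothetical prompts standing in for `padded_prompts` from the script.
padded_prompts = ["Hello, my name is"] * 4

# Old format: bare tuples of (prompt, input_len, output_len), with input_len
# dummied to 0.
legacy_requests = [(p, 0, 512) for p in padded_prompts]

# New format: newer bench_serving expects DatasetRow objects; the positional
# arguments mirror the old tuple layout.
input_requests: List[DatasetRow] = [DatasetRow(p, 0, 512) for p in padded_prompts]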
@@ -69,6 +69,8 @@ def send_one_batch(base_url, num_prompts, batch_size):
         random_output_len=None,
         random_range_ratio=None,
         output_file=None,
+        warmup_requests=1,
+        output_details=False,
     )
     set_global_args(args)
     tokenizer = FakeTokenizer()
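The two new fields are presumably read by newer bench_serving through the global args, so omitting them would break the `benchmark` call. A minimal sketch of the dummy-args setup, assuming warmup_requests is a warmup request count and output_details toggles per-request detail output (both inferred, not confirmed by this diff):

from types import SimpleNamespace

from sglang.bench_serving import set_global_args

# Only the fields touched by this hunk are shown; the real namespace in
# bench_speculative.py carries many more dummy entries.
args = SimpleNamespace(
    random_output_len=None,
    random_range_ratio=None,
    output_file=None,
    warmup_requests=1,     # newly required field (assumed: warmup request count)
    output_details=False,  # newly required field (assumed: per-request detail toggle)
)
set_global_args(args)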
@@ -97,7 +99,9 @@ def send_one_batch(base_url, num_prompts, batch_size):
     server_info = requests.get(base_url + "/get_server_info").json()
     # We use 20% percentile instead of median on purpose
-    step_time = np.percentile(server_info["step_time_dict"][str(batch_size)], 20)
+    step_time = np.percentile(
+        server_info["internal_states"][0]["step_time_dict"][str(batch_size)], 20
+    )
     speed = 1 / step_time * acc_length
     return (
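A minimal standalone sketch of the updated step-time lookup, assuming a server is already running at a hypothetical base_url and that /get_server_info now nests scheduler state under "internal_states" as the new lookup implies:

import numpy as np
import requests

base_url = "http://127.0.0.1:30000"  # hypothetical local server address
batch_size = 8                       # hypothetical batch size to look up

server_info = requests.get(base_url + "/get_server_info").json()

# The step-time samples now live under the first entry of "internal_states"
# instead of at the top level of the response.
step_times = server_info["internal_states"][0]["step_time_dict"][str(batch_size)]

# 20th percentile rather than the median, matching the script's comment.
step_time = np.percentile(step_times, 20)
print(f"p20 step time for batch size {batch_size}: {step_time:.4f} s")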