Unverified Commit 0b24af4d authored by Zhao Chen's avatar Zhao Chen Committed by GitHub
Browse files

test: support return logprobs in bench_offline_throughput test (#12462)


Signed-off-by: Zhao Chen <zhaochen.zju@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent a209fb05
...@@ -60,6 +60,8 @@ class BenchArgs: ...@@ -60,6 +60,8 @@ class BenchArgs:
skip_warmup: bool = False skip_warmup: bool = False
do_not_exit: bool = False do_not_exit: bool = False
prompt_suffix: str = "" prompt_suffix: str = ""
return_logprob: bool = False
logprob_start_len: int = -1
@staticmethod @staticmethod
def add_cli_args(parser: argparse.ArgumentParser): def add_cli_args(parser: argparse.ArgumentParser):
...@@ -187,6 +189,17 @@ class BenchArgs: ...@@ -187,6 +189,17 @@ class BenchArgs:
default="", default="",
help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.", help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
) )
parser.add_argument(
"--return-logprob",
action="store_true",
help="Enable returning log probabilities.",
)
parser.add_argument(
"--logprob-start-len",
type=int,
default=-1,
help="Start length for logprob. -1 means only return logprobs for output tokens (default). 0 means return logprobs for all tokens including input.",
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
...@@ -201,6 +214,8 @@ def throughput_test_once( ...@@ -201,6 +214,8 @@ def throughput_test_once(
ignore_eos: bool, ignore_eos: bool,
extra_request_body: Dict, extra_request_body: Dict,
profile: bool, profile: bool,
return_logprob: bool = False,
logprob_start_len: int = -1,
): ):
measurement_results = { measurement_results = {
"backend": backend_name, "backend": backend_name,
...@@ -233,7 +248,12 @@ def throughput_test_once( ...@@ -233,7 +248,12 @@ def throughput_test_once(
backend.start_profile() backend.start_profile()
st = time.perf_counter() st = time.perf_counter()
gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params) gen_out = backend.generate(
prompt=prompt,
sampling_params=sampling_params,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
)
latency = time.perf_counter() - st latency = time.perf_counter() - st
if profile: if profile:
...@@ -355,6 +375,8 @@ def throughput_test( ...@@ -355,6 +375,8 @@ def throughput_test(
ignore_eos=not bench_args.disable_ignore_eos, ignore_eos=not bench_args.disable_ignore_eos,
extra_request_body=extra_request_body, extra_request_body=extra_request_body,
profile=False, profile=False,
return_logprob=bench_args.return_logprob,
logprob_start_len=bench_args.logprob_start_len,
) )
time.sleep(0.5) time.sleep(0.5)
...@@ -366,6 +388,8 @@ def throughput_test( ...@@ -366,6 +388,8 @@ def throughput_test(
ignore_eos=not bench_args.disable_ignore_eos, ignore_eos=not bench_args.disable_ignore_eos,
extra_request_body=extra_request_body, extra_request_body=extra_request_body,
profile=bench_args.profile, profile=bench_args.profile,
return_logprob=bench_args.return_logprob,
logprob_start_len=bench_args.logprob_start_len,
) )
backend.shutdown() backend.shutdown()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment