Unverified commit 3b44bbee, authored by Lianmin Zheng, committed by GitHub
Browse files

Allow passing extra request body to bench_offline_throughput.py (#2085)

parent 80e2c4a8
...@@ -23,7 +23,7 @@ import json ...@@ -23,7 +23,7 @@ import json
import logging import logging
import random import random
import time import time
from typing import List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import numpy as np import numpy as np
...@@ -55,6 +55,7 @@ class BenchArgs: ...@@ -55,6 +55,7 @@ class BenchArgs:
gen_question_len: int = 128 gen_question_len: int = 128
gen_output_len: int = 256 gen_output_len: int = 256
disable_ignore_eos: bool = False disable_ignore_eos: bool = False
extra_request_body: Optional[str] = None
seed: int = 1 seed: int = 1
do_not_exit: bool = False do_not_exit: bool = False
...@@ -143,6 +144,13 @@ class BenchArgs: ...@@ -143,6 +144,13 @@ class BenchArgs:
default=BenchArgs.disable_ignore_eos, default=BenchArgs.disable_ignore_eos,
help="Disable ignore EOS token", help="Disable ignore EOS token",
) )
parser.add_argument(
"--extra-request-body",
metavar='{"key1": "value1", "key2": "value2"}',
type=str,
help="Append given JSON object to the request payload. You can use this to specify"
"additional generate params like sampling params.",
)
parser.add_argument("--seed", type=int, default=1, help="The random seed.") parser.add_argument("--seed", type=int, default=1, help="The random seed.")
parser.add_argument( parser.add_argument(
"--do-not-exit", "--do-not-exit",
...@@ -161,6 +169,7 @@ def throughput_test_once( ...@@ -161,6 +169,7 @@ def throughput_test_once(
backend, backend,
reqs: List[Tuple[str, int, int]], reqs: List[Tuple[str, int, int]],
ignore_eos: bool, ignore_eos: bool,
extra_request_body: Dict,
): ):
measurement_results = { measurement_results = {
"backend": backend_name, "backend": backend_name,
...@@ -180,6 +189,7 @@ def throughput_test_once( ...@@ -180,6 +189,7 @@ def throughput_test_once(
"temperature": 0, "temperature": 0,
"max_new_tokens": r[2], "max_new_tokens": r[2],
"ignore_eos": ignore_eos, "ignore_eos": ignore_eos,
**extra_request_body,
} }
for r in reqs for r in reqs
] ]
...@@ -233,6 +243,11 @@ def throughput_test( ...@@ -233,6 +243,11 @@ def throughput_test(
random.seed(bench_args.seed) random.seed(bench_args.seed)
np.random.seed(bench_args.seed) np.random.seed(bench_args.seed)
# Parse args
extra_request_body = {}
if bench_args.extra_request_body:
extra_request_body = json.loads(args.extra_request_body)
# Read dataset # Read dataset
input_requests = get_dataset(bench_args, tokenizer) input_requests = get_dataset(bench_args, tokenizer)
...@@ -252,6 +267,7 @@ def throughput_test( ...@@ -252,6 +267,7 @@ def throughput_test(
backend=backend, backend=backend,
reqs=warmup_requests, reqs=warmup_requests,
ignore_eos=not bench_args.disable_ignore_eos, ignore_eos=not bench_args.disable_ignore_eos,
extra_request_body=extra_request_body,
) )
logging.info("\nBenchmark...") logging.info("\nBenchmark...")
...@@ -260,6 +276,7 @@ def throughput_test( ...@@ -260,6 +276,7 @@ def throughput_test(
backend=backend, backend=backend,
reqs=input_requests, reqs=input_requests,
ignore_eos=not bench_args.disable_ignore_eos, ignore_eos=not bench_args.disable_ignore_eos,
extra_request_body=extra_request_body,
) )
if bench_args.result_filename: if bench_args.result_filename:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment