Unverified Commit 6bebef60 authored by fzyzcjy, committed by GitHub
Browse files

Support accurate length control for bench serving (#6594)

parent 25be63d0
...@@ -340,7 +340,7 @@ async def async_request_sglang_generate( ...@@ -340,7 +340,7 @@ async def async_request_sglang_generate(
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
payload = { payload = {
"text": prompt, ("text" if isinstance(prompt, str) else "input_ids"): prompt,
"sampling_params": { "sampling_params": {
"temperature": 0.0, "temperature": 0.0,
"max_new_tokens": request_func_input.output_len, "max_new_tokens": request_func_input.output_len,
...@@ -494,6 +494,7 @@ def get_tokenizer( ...@@ -494,6 +494,7 @@ def get_tokenizer(
def get_dataset(args, tokenizer): def get_dataset(args, tokenizer):
if args.dataset_name == "sharegpt": if args.dataset_name == "sharegpt":
assert not args.tokenize_prompt
input_requests = sample_sharegpt_requests( input_requests = sample_sharegpt_requests(
dataset_path=args.dataset_path, dataset_path=args.dataset_path,
num_requests=args.num_prompts, num_requests=args.num_prompts,
...@@ -512,8 +513,10 @@ def get_dataset(args, tokenizer): ...@@ -512,8 +513,10 @@ def get_dataset(args, tokenizer):
tokenizer=tokenizer, tokenizer=tokenizer,
dataset_path=args.dataset_path, dataset_path=args.dataset_path,
random_sample=args.dataset_name == "random", random_sample=args.dataset_name == "random",
return_text=not args.tokenize_prompt,
) )
elif args.dataset_name == "generated-shared-prefix": elif args.dataset_name == "generated-shared-prefix":
assert not args.tokenize_prompt
input_requests = sample_generated_shared_prefix_requests( input_requests = sample_generated_shared_prefix_requests(
num_groups=args.gsp_num_groups, num_groups=args.gsp_num_groups,
prompts_per_group=args.gsp_prompts_per_group, prompts_per_group=args.gsp_prompts_per_group,
...@@ -524,6 +527,7 @@ def get_dataset(args, tokenizer): ...@@ -524,6 +527,7 @@ def get_dataset(args, tokenizer):
args=args, args=args,
) )
elif args.dataset_name == "mmmu": elif args.dataset_name == "mmmu":
assert not args.tokenize_prompt
input_requests = sample_mmmu_requests( input_requests = sample_mmmu_requests(
num_requests=args.num_prompts, num_requests=args.num_prompts,
tokenizer=tokenizer, tokenizer=tokenizer,
...@@ -1495,6 +1499,9 @@ def run_benchmark(args_: argparse.Namespace): ...@@ -1495,6 +1499,9 @@ def run_benchmark(args_: argparse.Namespace):
if not hasattr(args, "output_details"): if not hasattr(args, "output_details"):
args.output_details = False args.output_details = False
if not hasattr(args, "tokenize_prompt"):
args.tokenize_prompt = False
print(f"benchmark_args={args}") print(f"benchmark_args={args}")
# Set global environments # Set global environments
...@@ -1506,6 +1513,11 @@ def run_benchmark(args_: argparse.Namespace): ...@@ -1506,6 +1513,11 @@ def run_benchmark(args_: argparse.Namespace):
if args.extra_request_body: if args.extra_request_body:
extra_request_body = json.loads(args.extra_request_body) extra_request_body = json.loads(args.extra_request_body)
if args.tokenize_prompt:
assert (
args.backend == "sglang"
), "`--tokenize-prompt` only compatible with `--backend sglang` currently"
# Set url # Set url
if args.port is None: if args.port is None:
args.port = { args.port = {
...@@ -1812,6 +1824,11 @@ if __name__ == "__main__": ...@@ -1812,6 +1824,11 @@ if __name__ == "__main__":
default=1, default=1,
help="Number of warmup requests to run before the benchmark", help="Number of warmup requests to run before the benchmark",
) )
parser.add_argument(
"--tokenize-prompt",
action="store_true",
help="Use integer ids instead of string for inputs. Useful to control prompt lengths accurately",
)
group = parser.add_argument_group("generated-shared-prefix dataset arguments") group = parser.add_argument_group("generated-shared-prefix dataset arguments")
group.add_argument( group.add_argument(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment