Unverified commit 8df7353a, authored by fzyzcjy, committed by GitHub
Browse files

Support sgl-router parallel_batch in bench_one_batch_server (#10506)

parent ae4be601
......@@ -48,6 +48,7 @@ class BenchArgs:
profile_steps: int = 3
profile_by_stage: bool = False
dataset_path: str = ""
parallel_batch: bool = False
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
......@@ -90,6 +91,7 @@ class BenchArgs:
default=BenchArgs.dataset_path,
help="Path to the dataset.",
)
parser.add_argument("--parallel-batch", action="store_true")
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
......@@ -146,6 +148,7 @@ def run_one_case(
profile_steps: int = 3,
profile_by_stage: bool = False,
dataset_path: str = "",
parallel_batch: bool = False,
):
requests.post(url + "/flush_cache")
input_requests = sample_random_requests(
......@@ -192,6 +195,7 @@ def run_one_case(
},
"return_logprob": return_logprob,
"stream": True,
**({"parallel_batch": parallel_batch} if parallel_batch else {}),
},
stream=True,
)
......@@ -354,6 +358,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
result_filename="",
tokenizer=tokenizer,
dataset_path=bench_args.dataset_path,
parallel_batch=bench_args.parallel_batch,
)
print("=" * 8 + " Warmup End " + "=" * 8 + "\n")
......@@ -378,6 +383,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
result_filename=bench_args.result_filename,
tokenizer=tokenizer,
dataset_path=bench_args.dataset_path,
parallel_batch=bench_args.parallel_batch,
)
)
......@@ -404,6 +410,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
profile_steps=bench_args.profile_steps,
profile_by_stage=bench_args.profile_by_stage,
dataset_path=bench_args.dataset_path,
parallel_batch=bench_args.parallel_batch,
)[-1],
)
)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment