Unverified Commit 8df7353a authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Support sgl-router parallel_batch in bench_one_batch_server (#10506)

parent ae4be601
```diff
@@ -48,6 +48,7 @@ class BenchArgs:
     profile_steps: int = 3
     profile_by_stage: bool = False
     dataset_path: str = ""
+    parallel_batch: bool = False

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -90,6 +91,7 @@ class BenchArgs:
             default=BenchArgs.dataset_path,
             help="Path to the dataset.",
         )
+        parser.add_argument("--parallel-batch", action="store_true")

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -146,6 +148,7 @@ def run_one_case(
     profile_steps: int = 3,
     profile_by_stage: bool = False,
     dataset_path: str = "",
+    parallel_batch: bool = False,
 ):
     requests.post(url + "/flush_cache")
     input_requests = sample_random_requests(
@@ -192,6 +195,7 @@ def run_one_case(
             },
             "return_logprob": return_logprob,
             "stream": True,
+            **({"parallel_batch": parallel_batch} if parallel_batch else {}),
         },
         stream=True,
     )
@@ -354,6 +358,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
             result_filename="",
             tokenizer=tokenizer,
             dataset_path=bench_args.dataset_path,
+            parallel_batch=bench_args.parallel_batch,
         )
     print("=" * 8 + " Warmup End " + "=" * 8 + "\n")
@@ -378,6 +383,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                 result_filename=bench_args.result_filename,
                 tokenizer=tokenizer,
                 dataset_path=bench_args.dataset_path,
+                parallel_batch=bench_args.parallel_batch,
             )
         )
@@ -404,6 +410,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                     profile_steps=bench_args.profile_steps,
                     profile_by_stage=bench_args.profile_by_stage,
                     dataset_path=bench_args.dataset_path,
+                    parallel_batch=bench_args.parallel_batch,
                 )[-1],
             )
         )
...
```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment