Unverified Commit 111b1379 authored by miter's avatar miter Committed by GitHub
Browse files

add dataset_path for bench_one_batch_server.py (#10113)


Signed-off-by: default avatarlinhuang <linhuang@ruijie.com.cn>
Co-authored-by: default avatarlinhuang <linhuang@ruijie.com.cn>
parent 41628dc1
...@@ -47,6 +47,7 @@ class BenchArgs: ...@@ -47,6 +47,7 @@ class BenchArgs:
profile: bool = False profile: bool = False
profile_steps: int = 3 profile_steps: int = 3
profile_by_stage: bool = False profile_by_stage: bool = False
dataset_path: str = ""
@staticmethod @staticmethod
def add_cli_args(parser: argparse.ArgumentParser): def add_cli_args(parser: argparse.ArgumentParser):
...@@ -83,6 +84,9 @@ class BenchArgs: ...@@ -83,6 +84,9 @@ class BenchArgs:
"--profile-steps", type=int, default=BenchArgs.profile_steps "--profile-steps", type=int, default=BenchArgs.profile_steps
) )
parser.add_argument("--profile-by-stage", action="store_true") parser.add_argument("--profile-by-stage", action="store_true")
parser.add_argument(
"--dataset-path", type=str, default=BenchArgs.dataset_path, help="Path to the dataset."
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
...@@ -138,6 +142,7 @@ def run_one_case( ...@@ -138,6 +142,7 @@ def run_one_case(
profile: bool = False, profile: bool = False,
profile_steps: int = 3, profile_steps: int = 3,
profile_by_stage: bool = False, profile_by_stage: bool = False,
dataset_path: str = "",
): ):
requests.post(url + "/flush_cache") requests.post(url + "/flush_cache")
input_requests = sample_random_requests( input_requests = sample_random_requests(
...@@ -146,7 +151,7 @@ def run_one_case( ...@@ -146,7 +151,7 @@ def run_one_case(
num_prompts=batch_size, num_prompts=batch_size,
range_ratio=1.0, range_ratio=1.0,
tokenizer=tokenizer, tokenizer=tokenizer,
dataset_path="", dataset_path=dataset_path,
random_sample=True, random_sample=True,
return_text=False, return_text=False,
) )
...@@ -345,6 +350,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): ...@@ -345,6 +350,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
run_name="", run_name="",
result_filename="", result_filename="",
tokenizer=tokenizer, tokenizer=tokenizer,
dataset_path=bench_args.dataset_path
) )
print("=" * 8 + " Warmup End " + "=" * 8 + "\n") print("=" * 8 + " Warmup End " + "=" * 8 + "\n")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment