Unverified Commit 3dfc6023 authored by Yuhong Guo's avatar Yuhong Guo Committed by GitHub
Browse files

Fix bench_serving with random-ids (#5214)

parent 15e91d72
......@@ -490,7 +490,7 @@ def get_dataset(args, tokenizer):
prompt_suffix=args.prompt_suffix,
apply_chat_template=args.apply_chat_template,
)
elif args.dataset_name == "random":
elif args.dataset_name.startswith("random"):
input_requests = sample_random_requests(
input_len=args.random_input_len,
output_len=args.random_output_len,
......@@ -498,6 +498,7 @@ def get_dataset(args, tokenizer):
range_ratio=args.random_range_ratio,
tokenizer=tokenizer,
dataset_path=args.dataset_path,
random_sample=args.dataset_name == "random",
)
elif args.dataset_name == "generated-shared-prefix":
input_requests = sample_generated_shared_prefix_requests(
......@@ -687,6 +688,7 @@ def sample_random_requests(
range_ratio: float,
tokenizer: PreTrainedTokenizerBase,
dataset_path: str,
random_sample: bool = True,
) -> List[Tuple[str, int, int]]:
input_lens = np.random.randint(
......@@ -700,11 +702,15 @@ def sample_random_requests(
size=num_prompts,
)
if True:
if random_sample:
# Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
# Download sharegpt if necessary
if not os.path.isfile(dataset_path):
print(
"If you do not want to randomly sample from a dataset,"
" please use --dataset-name random-ids."
)
dataset_path = download_and_cache_file(SHAREGPT_URL)
# Load the dataset.
......@@ -1223,7 +1229,7 @@ async def benchmark(
output_file_name = args.output_file
else:
now = datetime.now().strftime("%m%d")
if args.dataset_name == "random":
if args.dataset_name.startswith("random"):
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
else:
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"
......@@ -1442,7 +1448,7 @@ if __name__ == "__main__":
"--dataset-name",
type=str,
default="sharegpt",
choices=["sharegpt", "random", "generated-shared-prefix"],
choices=["sharegpt", "random", "random-ids", "generated-shared-prefix"],
help="Name of the dataset to benchmark on.",
)
parser.add_argument(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment