Unverified Commit 55bd97f3 authored by Yineng Zhang, committed by GitHub
Browse files

minor: add dataset dump and questions shuffle (#2093)

parent e57c3e12
...@@ -15,6 +15,7 @@ import argparse ...@@ -15,6 +15,7 @@ import argparse
import asyncio import asyncio
import json import json
import os import os
import pickle
import random import random
import resource import resource
import sys import sys
...@@ -682,6 +683,11 @@ def sample_generated_shared_prefix_requests( ...@@ -682,6 +683,11 @@ def sample_generated_shared_prefix_requests(
output_len: int, output_len: int,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]: ) -> List[Tuple[str, int, int]]:
if args.generated_input_path and os.path.exists(args.generated_input_path):
print(f"\nloading generated input data from {args.generated_input_path}")
with open(args.generated_input_path, "rb") as f:
return pickle.load(f)
"""Generate benchmark requests with shared system prompts using random tokens.""" """Generate benchmark requests with shared system prompts using random tokens."""
# Generate system prompts for each group # Generate system prompts for each group
system_prompts = [] system_prompts = []
...@@ -695,6 +701,9 @@ def sample_generated_shared_prefix_requests( ...@@ -695,6 +701,9 @@ def sample_generated_shared_prefix_requests(
question = gen_prompt(tokenizer, question_len) question = gen_prompt(tokenizer, question_len)
questions.append(question) questions.append(question)
# Shuffle questions
random.shuffle(questions)
# Combine system prompts with questions # Combine system prompts with questions
input_requests = [] input_requests = []
total_input_tokens = 0 total_input_tokens = 0
...@@ -723,6 +732,11 @@ def sample_generated_shared_prefix_requests( ...@@ -723,6 +732,11 @@ def sample_generated_shared_prefix_requests(
print( print(
f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n" f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
) )
if args.generated_input_save_path:
print(f"Saving generated input data to {args.generated_input_save_path}")
os.makedirs(os.path.dirname(args.generated_input_save_path), exist_ok=True)
with open(args.generated_input_save_path, "wb") as f:
pickle.dump(input_requests, f)
return input_requests return input_requests
...@@ -1331,6 +1345,16 @@ if __name__ == "__main__": ...@@ -1331,6 +1345,16 @@ if __name__ == "__main__":
default=256, default=256,
help="Target length in tokens for outputs in generated-shared-prefix dataset", help="Target length in tokens for outputs in generated-shared-prefix dataset",
) )
parser.add_argument(
"--generated-input-save-path",
type=str,
help="Path to save generated input data",
)
parser.add_argument(
"--generated-input-path",
type=str,
help="Path to load previously generated input data",
)
args = parser.parse_args() args = parser.parse_args()
run_benchmark(args) run_benchmark(args)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment