Unverified Commit 7c2ec0fe authored by Yasmin Moslem's avatar Yasmin Moslem Committed by GitHub
Browse files

[Benchmarking] Add disable_shuffle option for dataset loading (#26258)


Signed-off-by: default avatarYasmin Moslem <48152713+ymoslem@users.noreply.github.com>
parent 039b6bad
......@@ -96,6 +96,8 @@ class BenchmarkDataset(ABC):
self,
dataset_path: Optional[str] = None,
random_seed: int = DEFAULT_SEED,
disable_shuffle: bool = False,
**kwargs,
) -> None:
"""
Initialize the BenchmarkDataset with an optional dataset path and random
......@@ -111,6 +113,7 @@ class BenchmarkDataset(ABC):
# Set the random seed, ensuring that a None value is replaced with the
# default seed.
self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
self.disable_shuffle = disable_shuffle
self.data = None
def apply_multimodal_chat_transformation(
......@@ -1044,6 +1047,7 @@ class ShareGPTDataset(BenchmarkDataset):
if "conversations" in entry and len(entry["conversations"]) >= 2
]
random.seed(self.random_seed)
if not getattr(self, "disable_shuffle", False):
random.shuffle(self.data)
def sample(
......@@ -1175,6 +1179,11 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
action="store_true",
help="Skip applying chat template to prompt for datasets that support it.",
)
parser.add_argument(
"--disable-shuffle",
action="store_true",
help="Disable shuffling of dataset samples for deterministic ordering.",
)
# group for dataset specific arguments
custom_group = parser.add_argument_group("custom dataset options")
......@@ -1441,7 +1450,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
args.request_id_prefix = ""
if args.dataset_name == "custom":
dataset = CustomDataset(dataset_path=args.dataset_path)
dataset = CustomDataset(
dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle
)
input_requests = dataset.sample(
num_requests=args.num_prompts,
tokenizer=tokenizer,
......@@ -1452,7 +1463,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
)
elif args.dataset_name == "sonnet":
dataset = SonnetDataset(dataset_path=args.dataset_path)
dataset = SonnetDataset(
dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle
)
# For the "sonnet" dataset, formatting depends on the backend.
if args.backend == "openai-chat":
input_requests = dataset.sample(
......@@ -1586,6 +1599,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
random_seed=args.seed,
no_stream=args.no_stream,
hf_name=args.hf_name,
disable_shuffle=args.disable_shuffle,
).sample(
num_requests=args.num_prompts,
tokenizer=tokenizer,
......@@ -1600,7 +1614,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
# For datasets that follow a similar structure, use a mapping.
dataset_mapping = {
"spec_bench": lambda: SpecBench(
dataset_path=args.dataset_path, category=args.spec_bench_category
dataset_path=args.dataset_path,
category=args.spec_bench_category,
disable_shuffle=args.disable_shuffle,
).sample(
num_requests=args.num_prompts,
tokenizer=tokenizer,
......@@ -1609,7 +1625,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
no_oversample=args.no_oversample,
),
"sharegpt": lambda: ShareGPTDataset(
random_seed=args.seed, dataset_path=args.dataset_path
random_seed=args.seed,
dataset_path=args.dataset_path,
disable_shuffle=args.disable_shuffle,
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
......@@ -1618,7 +1636,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
no_oversample=args.no_oversample,
),
"burstgpt": lambda: BurstGPTDataset(
random_seed=args.seed, dataset_path=args.dataset_path
random_seed=args.seed,
dataset_path=args.dataset_path,
disable_shuffle=args.disable_shuffle,
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
......@@ -1626,7 +1646,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
no_oversample=args.no_oversample,
),
"random": lambda: RandomDataset(
random_seed=args.seed, dataset_path=args.dataset_path
random_seed=args.seed,
dataset_path=args.dataset_path,
disable_shuffle=args.disable_shuffle,
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
......@@ -1639,7 +1661,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
no_oversample=args.no_oversample,
),
"random-mm": lambda: RandomMultiModalDataset(
random_seed=args.seed, dataset_path=args.dataset_path
random_seed=args.seed,
dataset_path=args.dataset_path,
disable_shuffle=args.disable_shuffle,
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
......@@ -1655,7 +1679,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
no_oversample=args.no_oversample,
),
"prefix_repetition": lambda: PrefixRepetitionRandomDataset(
random_seed=args.seed, dataset_path=args.dataset_path
random_seed=args.seed,
dataset_path=args.dataset_path,
disable_shuffle=args.disable_shuffle,
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
......@@ -1733,6 +1759,7 @@ class CustomDataset(BenchmarkDataset):
)
random.seed(self.random_seed)
if not getattr(self, "disable_shuffle", False):
random.shuffle(self.data)
def sample(
......@@ -1825,6 +1852,7 @@ class SpecBench(CustomDataset):
self.data.append({"prompt": prompt})
random.seed(self.random_seed)
if not getattr(self, "disable_shuffle", False):
random.shuffle(self.data)
def sample(self, **kwargs) -> list:
......@@ -2033,6 +2061,7 @@ class HuggingFaceDataset(BenchmarkDataset):
split=self.dataset_split,
streaming=self.load_stream,
)
if not getattr(self, "disable_shuffle", False):
self.data = self.data.shuffle(seed=self.random_seed)
......@@ -2849,6 +2878,7 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
abs(token_mismatch_total),
sign,
)
if not getattr(self, "disable_shuffle", False):
random.shuffle(requests)
return requests
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment