Unverified commit 7b1e0b07, authored by Talor Abramovich and committed by GitHub
Browse files

[Bugfix] Fix dataset name and path argument validation bug in vllm bench serve (#40288)


Signed-off-by: talora <talora@nvidia.com>
Signed-off-by: Talor Abramovich <talor19@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent d249a9e9
......@@ -1373,26 +1373,6 @@ class ShareGPTDataset(BenchmarkDataset):
return samples
class _ValidateDatasetArgs(argparse.Action):
    """Argparse action to validate dataset name and path compatibility.

    The action is meant to be attached to both ``--dataset-name`` and
    ``--dataset-path``. It stores the parsed value and rejects the
    combination of the 'random' dataset with a dataset path.

    Argparse invokes actions in command-line order, and the namespace holds
    the *default* dataset name until ``--dataset-name`` is actually seen.
    Validating against that default would wrongly reject
    ``--dataset-path FILE --dataset-name sharegpt``, so the check only runs
    once ``--dataset-name`` has been given explicitly. A path supplied
    without any ``--dataset-name`` cannot be judged while parsing is still
    in progress; that case has to be validated after parsing completes.
    """

    # Namespace attribute recording that --dataset-name was passed explicitly
    # (as opposed to the namespace still holding the argument's default).
    _NAME_SEEN_ATTR = "_dataset_name_explicit"

    def __call__(self, parser, namespace, values, option_string=None):
        """Store ``values`` under ``self.dest`` and validate name/path.

        Calls ``parser.error`` when the explicitly requested dataset name is
        'random' and a dataset path is set.
        """
        setattr(namespace, self.dest, values)

        if self.dest == "dataset_name":
            setattr(namespace, self._NAME_SEEN_ATTR, True)
        elif not getattr(namespace, self._NAME_SEEN_ATTR, False):
            # --dataset-name may still follow on the command line, so the
            # name currently on the namespace is only the default: defer.
            return

        # Get current values of both dataset_name and dataset_path
        dataset_name = getattr(namespace, "dataset_name", "random")
        dataset_path = getattr(namespace, "dataset_path", None)

        # Validate the combination
        if dataset_name == "random" and dataset_path is not None:
            parser.error(
                "Cannot use 'random' dataset with --dataset-path. "
                "Please specify the appropriate --dataset-name (e.g., "
                "'sharegpt', 'custom', 'sonnet') for your dataset file: "
                f"{dataset_path}"
            )
def add_dataset_parser(parser: FlexibleArgumentParser):
parser.add_argument(
"--trust-remote-code",
......@@ -1410,7 +1390,6 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"--dataset-name",
type=str,
default="random",
action=_ValidateDatasetArgs,
choices=[
"sharegpt",
"burstgpt",
......@@ -1436,7 +1415,6 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"--dataset-path",
type=str,
default=None,
action=_ValidateDatasetArgs,
help="Path to the sharegpt/sonnet dataset or the HF dataset ID if "
"using HF dataset.",
)
......@@ -1608,7 +1586,9 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"repetition dataset.",
)
speed_bench_group = parser.add_argument_group("speed bench dataset options")
speed_bench_group = parser.add_argument_group(
"speed bench dataset options", description=SpeedBench.__doc__
)
speed_bench_group.add_argument(
"--speed-bench-dataset-subset",
type=str,
......@@ -3603,11 +3583,15 @@ class MMStarDataset(HuggingFaceDataset):
class SpeedBench(CustomDataset):
"""
Implements the SPEED-Bench dataset: https://huggingface.co/datasets/nvidia/SPEED-Bench
SPEED-Bench dataset: https://huggingface.co/datasets/nvidia/SPEED-Bench
Download the dataset using:
curl -LsSf https://raw.githubusercontent.com/NVIDIA-NeMo/Skills/refs/heads/main/nemo_skills/dataset/speed-bench/prepare.py | python3 -
`curl -LsSf https://raw.githubusercontent.com/NVIDIA-NeMo/Skills/refs/heads/main/nemo_skills/dataset/speed-bench/prepare.py | python3 -`
""" # noqa: E501
DOWNLOAD_SCRIPT_URL = "https://raw.githubusercontent.com/NVIDIA-NeMo/Skills/refs/heads/main/nemo_skills/dataset/speed-bench/prepare.py"
def __init__(self, **kwargs) -> None:
self.dataset_subset = kwargs.pop("dataset_subset", "qualitative")
self.category = kwargs.pop("category", None)
......@@ -3618,6 +3602,13 @@ class SpeedBench(CustomDataset):
if self.dataset_path is None:
raise ValueError("dataset_path must be provided for loading data.")
if not Path(self.dataset_path).is_dir():
raise ValueError(
f"dataset_path {self.dataset_path} is not a directory. "
f"Please make sure to download the dataset from HuggingFace using "
f"`curl -LsSf {self.DOWNLOAD_SCRIPT_URL} | python3 -`"
)
self.data = []
# Load the JSONL file
......@@ -3628,7 +3619,11 @@ class SpeedBench(CustomDataset):
# check if the JSONL file has a 'messages' column
if "messages" not in jsonl_data.columns:
raise ValueError("JSONL file must contain a 'messages' column.")
raise ValueError(
"JSONL file must contain a 'messages' column. "
"Please make sure to download the dataset from HuggingFace using "
f"`curl -LsSf {self.DOWNLOAD_SCRIPT_URL} | python3 -`"
)
for _, row in jsonl_data.iterrows():
# sample only from a specific category if specified
......
......@@ -1711,12 +1711,25 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
trust_remote_code=args.trust_remote_code,
)
# Validate dataset name/path
if args.dataset_name is None:
raise ValueError(
"Please specify '--dataset-name' and the corresponding "
"'--dataset-path' if required."
)
if (
args.dataset_name
in ["random", "random-mm", "random-rerank", "prefix_repetition"]
and args.dataset_path is not None
):
raise ValueError(
f"Cannot use '{args.dataset_name}' dataset with --dataset-path. "
"Please specify the appropriate --dataset-name (e.g., "
"'sharegpt', 'custom', 'sonnet') for your dataset file: "
f"{args.dataset_path}"
)
# Map general --input-len and --output-len to all dataset-specific arguments
if args.input_len is not None:
args.random_input_len = args.input_len
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment