Unverified Commit 8e6c34b3 authored by Connor Henderson, committed by GitHub

fix: Allow only test_file in pytorch and flax summarization (#22293)

allow only test_file in pytorch and flax summarization
parent 4ccaf268
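
In brief: `DataTrainingArguments.__post_init__` previously raised unless a dataset name or a train/validation file was supplied, so prediction-only runs that pass just a `test_file` failed validation before any work started. Below is a minimal, self-contained sketch of the relaxed check; the `DataArgs` class and the `articles.json` filename are illustrative stand-ins rather than the repo's full class, and the three per-file extension checks are condensed into a loop.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DataArgs:
    # Illustrative subset of the fields validated in the diff.
    dataset_name: Optional[str] = None
    train_file: Optional[str] = None
    validation_file: Optional[str] = None
    test_file: Optional[str] = None

    def __post_init__(self):
        # After this change, any one of the three files (or a dataset name)
        # satisfies validation; previously a lone test_file was rejected.
        if (
            self.dataset_name is None
            and self.train_file is None
            and self.validation_file is None
            and self.test_file is None
        ):
            raise ValueError("Need either a dataset name or a training, validation, or test file.")
        # Condensed version of the per-file extension checks in the diff.
        for name in ("train_file", "validation_file", "test_file"):
            value = getattr(self, name)
            if value is not None:
                extension = value.split(".")[-1]
                assert extension in ["csv", "json"], f"`{name}` should be a csv or a json file."


DataArgs(test_file="articles.json")  # now valid: a lone test file passes validation
```

The diff also hoists the split-existence checks (`--do_train requires a train dataset`, etc.) up to where `column_names` is selected, so a missing split fails fast instead of midway through preprocessing.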
@@ -308,8 +308,13 @@ class DataTrainingArguments:
     )

     def __post_init__(self):
-        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
-            raise ValueError("Need either a dataset name or a training/validation file.")
+        if (
+            self.dataset_name is None
+            and self.train_file is None
+            and self.validation_file is None
+            and self.test_file is None
+        ):
+            raise ValueError("Need either a dataset name or a training, validation, or test file.")
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
@@ -317,6 +322,9 @@ class DataTrainingArguments:
             if self.validation_file is not None:
                 extension = self.validation_file.split(".")[-1]
                 assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+            if self.test_file is not None:
+                extension = self.test_file.split(".")[-1]
+                assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
         if self.val_max_target_length is None:
             self.val_max_target_length = self.max_target_length
@@ -553,10 +561,16 @@ def main():
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
     if training_args.do_train:
+        if "train" not in dataset:
+            raise ValueError("--do_train requires a train dataset")
         column_names = dataset["train"].column_names
     elif training_args.do_eval:
+        if "validation" not in dataset:
+            raise ValueError("--do_eval requires a validation dataset")
         column_names = dataset["validation"].column_names
     elif training_args.do_predict:
+        if "test" not in dataset:
+            raise ValueError("--do_predict requires a test dataset")
         column_names = dataset["test"].column_names
     else:
         logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
@@ -620,8 +634,6 @@ def main():
         return model_inputs

     if training_args.do_train:
-        if "train" not in dataset:
-            raise ValueError("--do_train requires a train dataset")
         train_dataset = dataset["train"]
         if data_args.max_train_samples is not None:
             max_train_samples = min(len(train_dataset), data_args.max_train_samples)
@@ -637,8 +649,6 @@ def main():
     if training_args.do_eval:
         max_target_length = data_args.val_max_target_length
-        if "validation" not in dataset:
-            raise ValueError("--do_eval requires a validation dataset")
         eval_dataset = dataset["validation"]
         if data_args.max_eval_samples is not None:
             max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
@@ -654,8 +664,6 @@ def main():
     if training_args.do_predict:
         max_target_length = data_args.val_max_target_length
-        if "test" not in dataset:
-            raise ValueError("--do_predict requires a test dataset")
         predict_dataset = dataset["test"]
         if data_args.max_predict_samples is not None:
             max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
...
@@ -262,8 +262,13 @@ class DataTrainingArguments:
     )

     def __post_init__(self):
-        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
-            raise ValueError("Need either a dataset name or a training/validation file.")
+        if (
+            self.dataset_name is None
+            and self.train_file is None
+            and self.validation_file is None
+            and self.test_file is None
+        ):
+            raise ValueError("Need either a dataset name or a training, validation, or test file.")
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
@@ -271,6 +276,9 @@ class DataTrainingArguments:
             if self.validation_file is not None:
                 extension = self.validation_file.split(".")[-1]
                 assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+            if self.test_file is not None:
+                extension = self.test_file.split(".")[-1]
+                assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
         if self.val_max_target_length is None:
             self.val_max_target_length = self.max_target_length
@@ -467,10 +475,16 @@ def main():
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
     if training_args.do_train:
+        if "train" not in raw_datasets:
+            raise ValueError("--do_train requires a train dataset")
         column_names = raw_datasets["train"].column_names
     elif training_args.do_eval:
+        if "validation" not in raw_datasets:
+            raise ValueError("--do_eval requires a validation dataset")
         column_names = raw_datasets["validation"].column_names
     elif training_args.do_predict:
+        if "test" not in raw_datasets:
+            raise ValueError("--do_predict requires a test dataset")
         column_names = raw_datasets["test"].column_names
     else:
         logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
@@ -546,8 +560,6 @@ def main():
         return model_inputs

     if training_args.do_train:
-        if "train" not in raw_datasets:
-            raise ValueError("--do_train requires a train dataset")
         train_dataset = raw_datasets["train"]
         if data_args.max_train_samples is not None:
             max_train_samples = min(len(train_dataset), data_args.max_train_samples)
@@ -564,8 +576,6 @@ def main():
     if training_args.do_eval:
         max_target_length = data_args.val_max_target_length
-        if "validation" not in raw_datasets:
-            raise ValueError("--do_eval requires a validation dataset")
         eval_dataset = raw_datasets["validation"]
         if data_args.max_eval_samples is not None:
             max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
@@ -582,8 +592,6 @@ def main():
     if training_args.do_predict:
         max_target_length = data_args.val_max_target_length
-        if "test" not in raw_datasets:
-            raise ValueError("--do_predict requires a test dataset")
         predict_dataset = raw_datasets["test"]
         if data_args.max_predict_samples is not None:
             max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
...