"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "4946c2c5003a1b71632a5c89c9372e30c820fea4"
Unverified Commit 8e6c34b3 authored by Connor Henderson's avatar Connor Henderson Committed by GitHub
Browse files

fix: Allow only test_file in pytorch and flax summarization (#22293)

allow only test_file in pytorch and flax summarization
parent 4ccaf268
......@@ -308,8 +308,13 @@ class DataTrainingArguments:
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
if (
self.dataset_name is None
and self.train_file is None
and self.validation_file is None
and self.test_file is None
):
raise ValueError("Need either a dataset name or a training, validation, or test file.")
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
......@@ -317,6 +322,9 @@ class DataTrainingArguments:
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
if self.test_file is not None:
extension = self.test_file.split(".")[-1]
assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
if self.val_max_target_length is None:
self.val_max_target_length = self.max_target_length
......@@ -553,10 +561,16 @@ def main():
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
if training_args.do_train:
if "train" not in dataset:
raise ValueError("--do_train requires a train dataset")
column_names = dataset["train"].column_names
elif training_args.do_eval:
if "validation" not in dataset:
raise ValueError("--do_eval requires a validation dataset")
column_names = dataset["validation"].column_names
elif training_args.do_predict:
if "test" not in dataset:
raise ValueError("--do_predict requires a test dataset")
column_names = dataset["test"].column_names
else:
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
......@@ -620,8 +634,6 @@ def main():
return model_inputs
if training_args.do_train:
if "train" not in dataset:
raise ValueError("--do_train requires a train dataset")
train_dataset = dataset["train"]
if data_args.max_train_samples is not None:
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
......@@ -637,8 +649,6 @@ def main():
if training_args.do_eval:
max_target_length = data_args.val_max_target_length
if "validation" not in dataset:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = dataset["validation"]
if data_args.max_eval_samples is not None:
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
......@@ -654,8 +664,6 @@ def main():
if training_args.do_predict:
max_target_length = data_args.val_max_target_length
if "test" not in dataset:
raise ValueError("--do_predict requires a test dataset")
predict_dataset = dataset["test"]
if data_args.max_predict_samples is not None:
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
......
......@@ -262,8 +262,13 @@ class DataTrainingArguments:
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
if (
self.dataset_name is None
and self.train_file is None
and self.validation_file is None
and self.test_file is None
):
raise ValueError("Need either a dataset name or a training, validation, or test file.")
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
......@@ -271,6 +276,9 @@ class DataTrainingArguments:
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
if self.test_file is not None:
extension = self.test_file.split(".")[-1]
assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
if self.val_max_target_length is None:
self.val_max_target_length = self.max_target_length
......@@ -467,10 +475,16 @@ def main():
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
if training_args.do_train:
if "train" not in raw_datasets:
raise ValueError("--do_train requires a train dataset")
column_names = raw_datasets["train"].column_names
elif training_args.do_eval:
if "validation" not in raw_datasets:
raise ValueError("--do_eval requires a validation dataset")
column_names = raw_datasets["validation"].column_names
elif training_args.do_predict:
if "test" not in raw_datasets:
raise ValueError("--do_predict requires a test dataset")
column_names = raw_datasets["test"].column_names
else:
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
......@@ -546,8 +560,6 @@ def main():
return model_inputs
if training_args.do_train:
if "train" not in raw_datasets:
raise ValueError("--do_train requires a train dataset")
train_dataset = raw_datasets["train"]
if data_args.max_train_samples is not None:
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
......@@ -564,8 +576,6 @@ def main():
if training_args.do_eval:
max_target_length = data_args.val_max_target_length
if "validation" not in raw_datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = raw_datasets["validation"]
if data_args.max_eval_samples is not None:
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
......@@ -582,8 +592,6 @@ def main():
if training_args.do_predict:
max_target_length = data_args.val_max_target_length
if "test" not in raw_datasets:
raise ValueError("--do_predict requires a test dataset")
predict_dataset = raw_datasets["test"]
if data_args.max_predict_samples is not None:
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment