Unverified Commit fd1d9f1a authored by Bhadresh Savani's avatar Bhadresh Savani Committed by GitHub
Browse files

[Example] Updating Question Answering examples for Predict Stage (#10792)

* added prediction stage and eval fix

* style correction

* removed extra lines
parent e8968bd0
......@@ -100,6 +100,10 @@ class DataTrainingArguments:
default=None,
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
)
test_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input test data file to evaluate the perplexity on (a text file)."},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
......@@ -136,6 +140,13 @@ class DataTrainingArguments:
"value if set."
},
)
max_test_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
"value if set."
},
)
version_2_with_negative: bool = field(
default=False, metadata={"help": "If true, some of the examples do not have an answer."}
)
......@@ -164,8 +175,13 @@ class DataTrainingArguments:
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
if (
self.dataset_name is None
and self.train_file is None
and self.validation_file is None
and self.test_file is None
):
raise ValueError("Need either a dataset name or a training/validation file/test_file.")
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
......@@ -173,6 +189,9 @@ class DataTrainingArguments:
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
if self.test_file is not None:
extension = self.test_file.split(".")[-1]
assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
def main():
......@@ -247,7 +266,9 @@ def main():
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.validation_file.split(".")[-1]
if data_args.test_file is not None:
data_files["test"] = data_args.test_file
extension = data_args.test_file.split(".")[-1]
datasets = load_dataset(extension, data_files=data_files, field="data")
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
......@@ -291,8 +312,10 @@ def main():
# Preprocessing is slighlty different for training and evaluation.
if training_args.do_train:
column_names = datasets["train"].column_names
else:
elif training_args.do_eval:
column_names = datasets["validation"].column_names
else:
column_names = datasets["test"].column_names
question_column_name = "question" if "question" in column_names else column_names[0]
context_column_name = "context" if "context" in column_names else column_names[1]
answer_column_name = "answers" if "answers" in column_names else column_names[2]
......@@ -444,12 +467,12 @@ def main():
if training_args.do_eval:
if "validation" not in datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = datasets["validation"]
eval_examples = datasets["validation"]
if data_args.max_val_samples is not None:
# We will select sample from whole data
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
eval_examples = eval_examples.select(range(data_args.max_val_samples))
# Validation Feature Creation
eval_dataset = eval_dataset.map(
eval_dataset = eval_examples.map(
prepare_validation_features,
batched=True,
num_proc=data_args.preprocessing_num_workers,
......@@ -460,6 +483,25 @@ def main():
# During Feature creation dataset samples might increase, we will select required samples again
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
if training_args.do_predict:
if "test" not in datasets:
raise ValueError("--do_predict requires a test dataset")
test_examples = datasets["test"]
if data_args.max_test_samples is not None:
# We will select sample from whole data
test_examples = test_examples.select(range(data_args.max_test_samples))
# Test Feature Creation
test_dataset = test_examples.map(
prepare_validation_features,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
if data_args.max_test_samples is not None:
# During Feature creation dataset samples might increase, we will select required samples again
test_dataset = test_dataset.select(range(data_args.max_test_samples))
# Data collator
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
# collator.
......@@ -470,7 +512,7 @@ def main():
)
# Post-processing:
def post_processing_function(examples, features, predictions):
def post_processing_function(examples, features, predictions, stage="eval"):
# Post-processing: we match the start logits and end logits to answers in the original context.
predictions = postprocess_qa_predictions(
examples=examples,
......@@ -482,6 +524,7 @@ def main():
null_score_diff_threshold=data_args.null_score_diff_threshold,
output_dir=training_args.output_dir,
is_world_process_zero=trainer.is_world_process_zero(),
prefix=stage,
)
# Format the result to the format the metric expects.
if data_args.version_2_with_negative:
......@@ -490,7 +533,8 @@ def main():
]
else:
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in datasets["validation"]]
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
......@@ -504,7 +548,7 @@ def main():
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
eval_examples=datasets["validation"] if training_args.do_eval else None,
eval_examples=eval_examples if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
post_process_function=post_processing_function,
......@@ -543,6 +587,18 @@ def main():
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Prediction
if training_args.do_predict:
logger.info("*** Predict ***")
results = trainer.predict(test_dataset, test_examples)
metrics = results.metrics
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)
def _mp_fn(index):
# For xla_spawn (TPUs)
......
......@@ -99,6 +99,10 @@ class DataTrainingArguments:
default=None,
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
)
test_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input test data file to test the perplexity on (a text file)."},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
......@@ -135,6 +139,13 @@ class DataTrainingArguments:
"value if set."
},
)
max_test_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
"value if set."
},
)
version_2_with_negative: bool = field(
default=False, metadata={"help": "If true, some of the examples do not have an answer."}
)
......@@ -163,8 +174,13 @@ class DataTrainingArguments:
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
if (
self.dataset_name is None
and self.train_file is None
and self.validation_file is None
and self.test_file is None
):
raise ValueError("Need either a dataset name or a training/validation/test file.")
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
......@@ -172,6 +188,9 @@ class DataTrainingArguments:
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
if self.test_file is not None:
extension = self.test_file.split(".")[-1]
assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
def main():
......@@ -241,9 +260,13 @@ def main():
data_files = {}
if data_args.train_file is not None:
data_files["train"] = data_args.train_file
extension = data_args.train_file.split(".")[-1]
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.train_file.split(".")[-1]
extension = data_args.validation_file.split(".")[-1]
if data_args.test_file is not None:
data_files["test"] = data_args.test_file
extension = data_args.test_file.split(".")[-1]
datasets = load_dataset(extension, data_files=data_files, field="data")
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
......@@ -278,8 +301,10 @@ def main():
# Preprocessing is slighlty different for training and evaluation.
if training_args.do_train:
column_names = datasets["train"].column_names
else:
elif training_args.do_eval:
column_names = datasets["validation"].column_names
else:
column_names = datasets["test"].column_names
question_column_name = "question" if "question" in column_names else column_names[0]
context_column_name = "context" if "context" in column_names else column_names[1]
answer_column_name = "answers" if "answers" in column_names else column_names[2]
......@@ -478,12 +503,12 @@ def main():
if training_args.do_eval:
if "validation" not in datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = datasets["validation"]
eval_examples = datasets["validation"]
if data_args.max_val_samples is not None:
# Selecting Eval Samples from Dataset
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
eval_examples = eval_examples.select(range(data_args.max_val_samples))
# Create Features from Eval Dataset
eval_dataset = eval_dataset.map(
eval_dataset = eval_examples.map(
prepare_validation_features,
batched=True,
num_proc=data_args.preprocessing_num_workers,
......@@ -494,6 +519,25 @@ def main():
# Selecting Samples from Dataset again since Feature Creation might increase samples size
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
if training_args.do_predict:
if "test" not in datasets:
raise ValueError("--do_predict requires a test dataset")
test_examples = datasets["test"]
if data_args.max_test_samples is not None:
# We will select sample from whole data
test_examples = test_examples.select(range(data_args.max_test_samples))
# Test Feature Creation
test_dataset = test_examples.map(
prepare_validation_features,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
if data_args.max_test_samples is not None:
# During Feature creation dataset samples might increase, we will select required samples again
test_dataset = test_dataset.select(range(data_args.max_test_samples))
# Data collator
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
# collator.
......@@ -504,7 +548,7 @@ def main():
)
# Post-processing:
def post_processing_function(examples, features, predictions):
def post_processing_function(examples, features, predictions, stage="eval"):
# Post-processing: we match the start logits and end logits to answers in the original context.
predictions, scores_diff_json = postprocess_qa_predictions_with_beam_search(
examples=examples,
......@@ -517,6 +561,7 @@ def main():
end_n_top=model.config.end_n_top,
output_dir=training_args.output_dir,
is_world_process_zero=trainer.is_world_process_zero(),
prefix=stage,
)
# Format the result to the format the metric expects.
if data_args.version_2_with_negative:
......@@ -526,7 +571,8 @@ def main():
]
else:
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in datasets["validation"]]
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
......@@ -540,7 +586,7 @@ def main():
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
eval_examples=datasets["validation"] if training_args.do_eval else None,
eval_examples=eval_examples if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
post_process_function=post_processing_function,
......@@ -580,6 +626,18 @@ def main():
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Prediction
if training_args.do_predict:
logger.info("*** Predict ***")
results = trainer.predict(test_dataset, test_examples)
metrics = results.metrics
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)
def _mp_fn(index):
# For xla_spawn (TPUs)
......
......@@ -98,7 +98,7 @@ class QuestionAnsweringTrainer(Trainer):
if isinstance(test_dataset, datasets.Dataset):
test_dataset.set_format(type=test_dataset.format["type"], columns=list(test_dataset.features.keys()))
eval_preds = self.post_process_function(test_examples, test_dataset, output.predictions)
eval_preds = self.post_process_function(test_examples, test_dataset, output.predictions, "test")
metrics = self.compute_metrics(eval_preds)
return PredictionOutput(predictions=eval_preds.predictions, label_ids=eval_preds.label_ids, metrics=metrics)
......@@ -215,14 +215,14 @@ def postprocess_qa_predictions(
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
prediction_file = os.path.join(
output_dir, "predictions.json" if prefix is None else f"predictions_{prefix}".json
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
)
nbest_file = os.path.join(
output_dir, "nbest_predictions.json" if prefix is None else f"nbest_predictions_{prefix}".json
output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
)
if version_2_with_negative:
null_odds_file = os.path.join(
output_dir, "null_odds.json" if prefix is None else f"null_odds_{prefix}".json
output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds_{prefix}.json"
)
logger.info(f"Saving predictions to {prediction_file}.")
......@@ -403,14 +403,14 @@ def postprocess_qa_predictions_with_beam_search(
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
prediction_file = os.path.join(
output_dir, "predictions.json" if prefix is None else f"predictions_{prefix}".json
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
)
nbest_file = os.path.join(
output_dir, "nbest_predictions.json" if prefix is None else f"nbest_predictions_{prefix}".json
output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
)
if version_2_with_negative:
null_odds_file = os.path.join(
output_dir, "null_odds.json" if prefix is None else f"null_odds_{prefix}".json
output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
)
print(f"Saving predictions to {prediction_file}.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment