"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1f72865726f7f8ca7d0202bb8cd2e487394f8c83"
Unverified commit fd1d9f1a authored by Bhadresh Savani, committed by GitHub

[Example] Updating Question Answering examples for Predict Stage (#10792)

* added prediction stage and eval fix

* style correction

* removed extra lines
parent e8968bd0
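This change wires a test split into the question answering examples: a --test_file argument, a --max_test_samples cap, test feature creation, and a prediction stage that logs and saves "test" metrics. As a minimal sketch (not part of this commit), a SQuAD-style JSON file the new --test_file flag could consume might be produced as follows; the field names are assumptions inferred from the loading code below (load_dataset(..., field="data") plus the question/context/answers column lookup):

import json

# Hypothetical test file; the "data" field and column names mirror what the
# scripts below expect, but adjust them to your own data.
test_data = {
    "data": [
        {
            "id": "example-0",
            "question": "Where is the Eiffel Tower?",
            "context": "The Eiffel Tower is a wrought-iron lattice tower in Paris, France.",
            # Answers are still used at predict time, because the metric compares
            # predictions against these references.
            "answers": {"text": ["Paris, France"], "answer_start": [52]},
        }
    ]
}

with open("test.json", "w") as f:
    json.dump(test_data, f)

# The updated scripts would then be invoked with flags along the lines of
# --do_predict --test_file test.json --max_test_samples 100 (paths illustrative).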
@@ -100,6 +100,10 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
     )
+    test_file: Optional[str] = field(
+        default=None,
+        metadata={"help": "An optional input test data file to evaluate the perplexity on (a text file)."},
+    )
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
@@ -136,6 +140,13 @@ class DataTrainingArguments:
             "value if set."
         },
     )
+    max_test_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
+            "value if set."
+        },
+    )
     version_2_with_negative: bool = field(
         default=False, metadata={"help": "If true, some of the examples do not have an answer."}
     )
@@ -164,8 +175,13 @@ class DataTrainingArguments:
     )

     def __post_init__(self):
-        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
-            raise ValueError("Need either a dataset name or a training/validation file.")
+        if (
+            self.dataset_name is None
+            and self.train_file is None
+            and self.validation_file is None
+            and self.test_file is None
+        ):
+            raise ValueError("Need either a dataset name or a training/validation file/test_file.")
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
@@ -173,6 +189,9 @@ class DataTrainingArguments:
             if self.validation_file is not None:
                 extension = self.validation_file.split(".")[-1]
                 assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+            if self.test_file is not None:
+                extension = self.test_file.split(".")[-1]
+                assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
def main():

@@ -247,7 +266,9 @@ def main():
         if data_args.validation_file is not None:
             data_files["validation"] = data_args.validation_file
             extension = data_args.validation_file.split(".")[-1]
+        if data_args.test_file is not None:
+            data_files["test"] = data_args.test_file
+            extension = data_args.test_file.split(".")[-1]
         datasets = load_dataset(extension, data_files=data_files, field="data")
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
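For reference, the data_files/extension logic above boils down to roughly the following when all three splits are local JSON files; this is an illustrative sketch with placeholder paths, not code from the commit. A single extension variable is reused, so train, validation, and test files are expected to share one format:

from datasets import load_dataset

data_files = {
    "train": "train.json",
    "validation": "dev.json",
    "test": "test.json",
}
# The script keeps whichever extension it saw last, so mixing csv and json
# across splits would not work here.
extension = data_files["test"].split(".")[-1]

datasets = load_dataset(extension, data_files=data_files, field="data")
print(datasets["test"].column_names)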
@@ -291,8 +312,10 @@ def main():
     # Preprocessing is slighlty different for training and evaluation.
     if training_args.do_train:
         column_names = datasets["train"].column_names
-    else:
+    elif training_args.do_eval:
         column_names = datasets["validation"].column_names
+    else:
+        column_names = datasets["test"].column_names
     question_column_name = "question" if "question" in column_names else column_names[0]
     context_column_name = "context" if "context" in column_names else column_names[1]
     answer_column_name = "answers" if "answers" in column_names else column_names[2]
@@ -444,12 +467,12 @@ def main():
     if training_args.do_eval:
         if "validation" not in datasets:
             raise ValueError("--do_eval requires a validation dataset")
-        eval_dataset = datasets["validation"]
+        eval_examples = datasets["validation"]
         if data_args.max_val_samples is not None:
             # We will select sample from whole data
-            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
+            eval_examples = eval_examples.select(range(data_args.max_val_samples))
         # Validation Feature Creation
-        eval_dataset = eval_dataset.map(
+        eval_dataset = eval_examples.map(
             prepare_validation_features,
             batched=True,
             num_proc=data_args.preprocessing_num_workers,
@@ -460,6 +483,25 @@ def main():
             # During Feature creation dataset samples might increase, we will select required samples again
             eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
+
+    if training_args.do_predict:
+        if "test" not in datasets:
+            raise ValueError("--do_predict requires a test dataset")
+        test_examples = datasets["test"]
+        if data_args.max_test_samples is not None:
+            # We will select sample from whole data
+            test_examples = test_examples.select(range(data_args.max_test_samples))
+        # Test Feature Creation
+        test_dataset = test_examples.map(
+            prepare_validation_features,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            remove_columns=column_names,
+            load_from_cache_file=not data_args.overwrite_cache,
+        )
+        if data_args.max_test_samples is not None:
+            # During Feature creation dataset samples might increase, we will select required samples again
+            test_dataset = test_dataset.select(range(data_args.max_test_samples))

     # Data collator
     # We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
     # collator.
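The second select() after the map above is needed because prepare_validation_features can emit several features per example: long contexts are split into overlapping chunks. A small standalone sketch of that behaviour (model name and lengths are arbitrary choices, not taken from the commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
question = "What is the capital of France?"
context = "word " * 600  # long enough to overflow max_length

encoded = tokenizer(
    question,
    context,
    truncation="only_second",  # only truncate/split the context
    max_length=384,
    stride=128,
    return_overflowing_tokens=True,
)
# One example turns into several features:
print(len(encoded["input_ids"]))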
@@ -470,7 +512,7 @@ def main():
     )

     # Post-processing:
-    def post_processing_function(examples, features, predictions):
+    def post_processing_function(examples, features, predictions, stage="eval"):
         # Post-processing: we match the start logits and end logits to answers in the original context.
         predictions = postprocess_qa_predictions(
             examples=examples,
@@ -482,6 +524,7 @@ def main():
             null_score_diff_threshold=data_args.null_score_diff_threshold,
             output_dir=training_args.output_dir,
             is_world_process_zero=trainer.is_world_process_zero(),
+            prefix=stage,
         )
         # Format the result to the format the metric expects.
         if data_args.version_2_with_negative:
@@ -490,7 +533,8 @@ def main():
             ]
         else:
             formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
-        references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in datasets["validation"]]
+
+        references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
         return EvalPrediction(predictions=formatted_predictions, label_ids=references)

     metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
@@ -504,7 +548,7 @@ def main():
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=eval_dataset if training_args.do_eval else None,
-        eval_examples=datasets["validation"] if training_args.do_eval else None,
+        eval_examples=eval_examples if training_args.do_eval else None,
         tokenizer=tokenizer,
         data_collator=data_collator,
         post_process_function=post_processing_function,
@@ -543,6 +587,18 @@ def main():
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)

+    # Prediction
+    if training_args.do_predict:
+        logger.info("*** Predict ***")
+        results = trainer.predict(test_dataset, test_examples)
+        metrics = results.metrics
+
+        max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
+        metrics["test_samples"] = min(max_test_samples, len(test_dataset))
+
+        trainer.log_metrics("test", metrics)
+        trainer.save_metrics("test", metrics)
+

 def _mp_fn(index):
     # For xla_spawn (TPUs)
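With the stage/prefix plumbing above, a --do_predict run leaves its artifacts in --output_dir. A rough sketch of reading them back; the exact file names ("test_predictions.json" from the prefix handling further down, "test_results.json" from Trainer.save_metrics) are assumptions about this version of the scripts:

import json
import os

output_dir = "outputs"  # whatever was passed as --output_dir

with open(os.path.join(output_dir, "test_results.json")) as f:
    test_metrics = json.load(f)  # e.g. exact_match / f1 plus test_samples

with open(os.path.join(output_dir, "test_predictions.json")) as f:
    predictions = json.load(f)  # maps example id -> predicted answer text

print(test_metrics)
print(list(predictions.items())[:3])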
...
@@ -99,6 +99,10 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
     )
+    test_file: Optional[str] = field(
+        default=None,
+        metadata={"help": "An optional input test data file to test the perplexity on (a text file)."},
+    )
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
@@ -135,6 +139,13 @@ class DataTrainingArguments:
             "value if set."
         },
     )
+    max_test_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
+            "value if set."
+        },
+    )
     version_2_with_negative: bool = field(
         default=False, metadata={"help": "If true, some of the examples do not have an answer."}
     )
@@ -163,8 +174,13 @@ class DataTrainingArguments:
     )

     def __post_init__(self):
-        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
-            raise ValueError("Need either a dataset name or a training/validation file.")
+        if (
+            self.dataset_name is None
+            and self.train_file is None
+            and self.validation_file is None
+            and self.test_file is None
+        ):
+            raise ValueError("Need either a dataset name or a training/validation/test file.")
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
@@ -172,6 +188,9 @@ class DataTrainingArguments:
             if self.validation_file is not None:
                 extension = self.validation_file.split(".")[-1]
                 assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+            if self.test_file is not None:
+                extension = self.test_file.split(".")[-1]
+                assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
def main():

@@ -241,9 +260,13 @@ def main():
         data_files = {}
         if data_args.train_file is not None:
             data_files["train"] = data_args.train_file
+            extension = data_args.train_file.split(".")[-1]
         if data_args.validation_file is not None:
             data_files["validation"] = data_args.validation_file
-            extension = data_args.train_file.split(".")[-1]
+            extension = data_args.validation_file.split(".")[-1]
+        if data_args.test_file is not None:
+            data_files["test"] = data_args.test_file
+            extension = data_args.test_file.split(".")[-1]
         datasets = load_dataset(extension, data_files=data_files, field="data")
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
@@ -278,8 +301,10 @@ def main():
     # Preprocessing is slighlty different for training and evaluation.
     if training_args.do_train:
         column_names = datasets["train"].column_names
-    else:
+    elif training_args.do_eval:
         column_names = datasets["validation"].column_names
+    else:
+        column_names = datasets["test"].column_names
     question_column_name = "question" if "question" in column_names else column_names[0]
     context_column_name = "context" if "context" in column_names else column_names[1]
     answer_column_name = "answers" if "answers" in column_names else column_names[2]
@@ -478,12 +503,12 @@ def main():
     if training_args.do_eval:
         if "validation" not in datasets:
             raise ValueError("--do_eval requires a validation dataset")
-        eval_dataset = datasets["validation"]
+        eval_examples = datasets["validation"]
         if data_args.max_val_samples is not None:
             # Selecting Eval Samples from Dataset
-            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
+            eval_examples = eval_examples.select(range(data_args.max_val_samples))
         # Create Features from Eval Dataset
-        eval_dataset = eval_dataset.map(
+        eval_dataset = eval_examples.map(
             prepare_validation_features,
             batched=True,
             num_proc=data_args.preprocessing_num_workers,
@@ -494,6 +519,25 @@ def main():
             # Selecting Samples from Dataset again since Feature Creation might increase samples size
             eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
+
+    if training_args.do_predict:
+        if "test" not in datasets:
+            raise ValueError("--do_predict requires a test dataset")
+        test_examples = datasets["test"]
+        if data_args.max_test_samples is not None:
+            # We will select sample from whole data
+            test_examples = test_examples.select(range(data_args.max_test_samples))
+        # Test Feature Creation
+        test_dataset = test_examples.map(
+            prepare_validation_features,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            remove_columns=column_names,
+            load_from_cache_file=not data_args.overwrite_cache,
+        )
+        if data_args.max_test_samples is not None:
+            # During Feature creation dataset samples might increase, we will select required samples again
+            test_dataset = test_dataset.select(range(data_args.max_test_samples))

     # Data collator
     # We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
     # collator.
@@ -504,7 +548,7 @@ def main():
     )

     # Post-processing:
-    def post_processing_function(examples, features, predictions):
+    def post_processing_function(examples, features, predictions, stage="eval"):
         # Post-processing: we match the start logits and end logits to answers in the original context.
         predictions, scores_diff_json = postprocess_qa_predictions_with_beam_search(
             examples=examples,
@@ -517,6 +561,7 @@ def main():
             end_n_top=model.config.end_n_top,
             output_dir=training_args.output_dir,
             is_world_process_zero=trainer.is_world_process_zero(),
+            prefix=stage,
         )
         # Format the result to the format the metric expects.
         if data_args.version_2_with_negative:
@@ -526,7 +571,8 @@ def main():
             ]
         else:
             formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
-        references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in datasets["validation"]]
+
+        references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
         return EvalPrediction(predictions=formatted_predictions, label_ids=references)

     metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
@@ -540,7 +586,7 @@ def main():
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=eval_dataset if training_args.do_eval else None,
-        eval_examples=datasets["validation"] if training_args.do_eval else None,
+        eval_examples=eval_examples if training_args.do_eval else None,
         tokenizer=tokenizer,
         data_collator=data_collator,
         post_process_function=post_processing_function,
@@ -580,6 +626,18 @@ def main():
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)

+    # Prediction
+    if training_args.do_predict:
+        logger.info("*** Predict ***")
+        results = trainer.predict(test_dataset, test_examples)
+        metrics = results.metrics
+
+        max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
+        metrics["test_samples"] = min(max_test_samples, len(test_dataset))
+
+        trainer.log_metrics("test", metrics)
+        trainer.save_metrics("test", metrics)
+

 def _mp_fn(index):
     # For xla_spawn (TPUs)
...
@@ -98,7 +98,7 @@ class QuestionAnsweringTrainer(Trainer):
         if isinstance(test_dataset, datasets.Dataset):
             test_dataset.set_format(type=test_dataset.format["type"], columns=list(test_dataset.features.keys()))

-        eval_preds = self.post_process_function(test_examples, test_dataset, output.predictions)
+        eval_preds = self.post_process_function(test_examples, test_dataset, output.predictions, "test")
         metrics = self.compute_metrics(eval_preds)

         return PredictionOutput(predictions=eval_preds.predictions, label_ids=eval_preds.label_ids, metrics=metrics)
@@ -215,14 +215,14 @@ def postprocess_qa_predictions(
         assert os.path.isdir(output_dir), f"{output_dir} is not a directory."

         prediction_file = os.path.join(
-            output_dir, "predictions.json" if prefix is None else f"predictions_{prefix}".json
+            output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
         )
         nbest_file = os.path.join(
-            output_dir, "nbest_predictions.json" if prefix is None else f"nbest_predictions_{prefix}".json
+            output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
         )
         if version_2_with_negative:
             null_odds_file = os.path.join(
-                output_dir, "null_odds.json" if prefix is None else f"null_odds_{prefix}".json
+                output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
             )

         logger.info(f"Saving predictions to {prediction_file}.")
@@ -403,14 +403,14 @@ def postprocess_qa_predictions_with_beam_search(
         assert os.path.isdir(output_dir), f"{output_dir} is not a directory."

         prediction_file = os.path.join(
-            output_dir, "predictions.json" if prefix is None else f"predictions_{prefix}".json
+            output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
        )
         nbest_file = os.path.join(
-            output_dir, "nbest_predictions.json" if prefix is None else f"nbest_predictions_{prefix}".json
+            output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
         )
         if version_2_with_negative:
             null_odds_file = os.path.join(
-                output_dir, "null_odds.json" if prefix is None else f"null_odds_{prefix}".json
+                output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
             )

         print(f"Saving predictions to {prediction_file}.")
...