Unverified commit e363e1d9, authored by Russell Klopfer, committed by GitHub

adds metric prefix. (#12057)

* adds metric prefix.

* update tests to include prefix
parent 8994c1e4
@@ -31,7 +31,7 @@ class QuestionAnsweringTrainer(Trainer):
         self.eval_examples = eval_examples
         self.post_process_function = post_process_function

-    def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None):
+    def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None, metric_key_prefix: str = "eval"):
         eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
         eval_dataloader = self.get_eval_dataloader(eval_dataset)
         eval_examples = self.eval_examples if eval_examples is None else eval_examples
@@ -56,6 +56,11 @@ class QuestionAnsweringTrainer(Trainer):
             eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions)
             metrics = self.compute_metrics(eval_preds)

+            # Prefix all keys with metric_key_prefix + '_'
+            for key in list(metrics.keys()):
+                if not key.startswith(f"{metric_key_prefix}_"):
+                    metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
             self.log(metrics)
         else:
             metrics = {}
@@ -67,7 +72,7 @@ class QuestionAnsweringTrainer(Trainer):
         self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
         return metrics

-    def predict(self, predict_dataset, predict_examples, ignore_keys=None):
+    def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"):
         predict_dataloader = self.get_test_dataloader(predict_dataset)

         # Temporarily disable metric computation, we will do it in the loop here.
@@ -92,4 +97,9 @@ class QuestionAnsweringTrainer(Trainer):
         predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict")
         metrics = self.compute_metrics(predictions)

+        # Prefix all keys with metric_key_prefix + '_'
+        for key in list(metrics.keys()):
+            if not key.startswith(f"{metric_key_prefix}_"):
+                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
         return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
@@ -213,7 +213,7 @@ class ExamplesTests(TestCasePlus):
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
-            run_squad.py
+            run_qa.py
             --model_name_or_path bert-base-uncased
             --version_2_with_negative
             --train_file tests/fixtures/tests_samples/SQUAD/sample.json
@@ -232,8 +232,8 @@ class ExamplesTests(TestCasePlus):
         with patch.object(sys, "argv", testargs):
             run_squad.main()
             result = get_results(tmp_dir)
-            self.assertGreaterEqual(result["f1"], 30)
-            self.assertGreaterEqual(result["exact"], 30)
+            self.assertGreaterEqual(result["eval_f1"], 30)
+            self.assertGreaterEqual(result["eval_exact"], 30)

     def test_run_swag(self):
         stream_handler = logging.StreamHandler(sys.stdout)